@@ -565,11 +565,12 @@ where
565
565
let sums = std:: mem:: take ( & mut self . sums ) ;
566
566
let nulls = self . null_state . build ( ) ;
567
567
568
+ assert_eq ! ( nulls. len( ) , sums. len( ) ) ;
568
569
assert_eq ! ( counts. len( ) , sums. len( ) ) ;
569
570
570
571
// don't evaluate averages with null inputs to avoid errors on null values
571
- let array : PrimitiveArray < T > = if let Some ( nulls ) = nulls . as_ref ( ) {
572
- assert_eq ! ( nulls. len ( ) , sums . len ( ) ) ;
572
+
573
+ let array : PrimitiveArray < T > = if nulls. null_count ( ) > 0 {
573
574
let mut builder = PrimitiveBuilder :: < T > :: with_capacity ( nulls. len ( ) ) ;
574
575
let iter = sums. into_iter ( ) . zip ( counts. into_iter ( ) ) . zip ( nulls. iter ( ) ) ;
575
576
@@ -587,7 +588,7 @@ where
587
588
. zip ( counts. into_iter ( ) )
588
589
. map ( |( sum, count) | ( self . avg_fn ) ( sum, count) )
589
590
. collect :: < Result < Vec < _ > > > ( ) ?;
590
- PrimitiveArray :: new ( averages. into ( ) , nulls) // no copy
591
+ PrimitiveArray :: new ( averages. into ( ) , Some ( nulls) ) // no copy
591
592
} ;
592
593
593
594
// fix up decimal precision and scale for decimals
@@ -598,9 +599,9 @@ where
598
599
599
600
// return arrays for sums and counts
600
601
fn state ( & mut self ) -> Result < Vec < ArrayRef > > {
601
- let nulls = self . null_state . build ( ) ;
602
+ let nulls = Some ( self . null_state . build ( ) ) ;
602
603
let counts = std:: mem:: take ( & mut self . counts ) ;
603
- let counts = UInt64Array :: from ( counts) ; // zero copy
604
+ let counts = UInt64Array :: new ( counts. into ( ) , nulls . clone ( ) ) ; // zero copy
604
605
605
606
let sums = std:: mem:: take ( & mut self . sums ) ;
606
607
let sums = PrimitiveArray :: < T > :: new ( sums. into ( ) , nulls) ; // zero copy
0 commit comments