@@ -21,10 +21,10 @@ use multiversion::multiversion;
2121use  std:: ops:: Add ; 
2222
2323use  crate :: array:: { 
24-     Array ,   BooleanArray ,   GenericBinaryArray ,   GenericStringArray ,   OffsetSizeTrait , 
25-     PrimitiveArray , 
24+     as_primitive_array ,   Array ,   ArrayAccessor ,   ArrayIter ,   BooleanArray , 
25+     GenericBinaryArray ,   GenericStringArray ,   OffsetSizeTrait ,   PrimitiveArray , 
2626} ; 
27- use  crate :: datatypes:: { ArrowNativeType ,  ArrowNumericType } ; 
27+ use  crate :: datatypes:: { ArrowNativeType ,  ArrowNumericType ,   DataType } ; 
2828
2929/// Generic test for NaN, the optimizer should be able to remove this for integer types. 
3030#[ inline]  
@@ -185,6 +185,37 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
185185} 
186186
187187/// Returns the sum of values in the array. 
188+ pub  fn  sum_dyn < T ,  A :  ArrayAccessor < Item  = T :: Native > > ( array :  A )  -> Option < T :: Native > 
189+ where 
190+     T :  ArrowNumericType , 
191+     T :: Native :  Add < Output  = T :: Native > , 
192+ { 
193+     match  array. data_type ( )  { 
194+         DataType :: Dictionary ( _,  _)  => { 
195+             let  null_count = array. null_count ( ) ; 
196+ 
197+             if  null_count == array. len ( )  { 
198+                 return  None ; 
199+             } 
200+ 
201+             let  iter = ArrayIter :: new ( array) ; 
202+             let  sum = iter
203+                 . into_iter ( ) 
204+                 . fold ( T :: default_value ( ) ,  |accumulator,  value| { 
205+                     if  let  Some ( value)  = value { 
206+                         accumulator + value
207+                     }  else  { 
208+                         accumulator
209+                     } 
210+                 } ) ; 
211+ 
212+             Some ( sum) 
213+         } 
214+         _ => sum :: < T > ( as_primitive_array ( & array) ) , 
215+     } 
216+ } 
217+ 
218+ /// Returns the sum of values in the primitive array. 
188219/// 
189220/// Returns `None` if the array is empty or only contains null values. 
190221#[ cfg( not( feature = "simd" ) ) ]  
@@ -583,7 +614,7 @@ mod simd {
583614    } 
584615} 
585616
586- /// Returns the sum of values in the array. 
617+ /// Returns the sum of values in the primitive  array. 
587618/// 
588619/// Returns `None` if the array is empty or only contains null values. 
589620#[ cfg( feature = "simd" ) ]  
@@ -625,6 +656,7 @@ mod tests {
625656    use  super :: * ; 
626657    use  crate :: array:: * ; 
627658    use  crate :: compute:: add; 
659+     use  crate :: datatypes:: { Int32Type ,  Int8Type } ; 
628660
629661    #[ test]  
630662    fn  test_primitive_array_sum ( )  { 
@@ -1003,4 +1035,27 @@ mod tests {
10031035        assert_eq ! ( Some ( true ) ,  min_boolean( & a) ) ; 
10041036        assert_eq ! ( Some ( true ) ,  max_boolean( & a) ) ; 
10051037    } 
1038+ 
1039+     #[ test]  
1040+     fn  test_sum_dyn ( )  { 
1041+         let  values = Int8Array :: from_iter_values ( [ 10_i8 ,  11 ,  12 ,  13 ,  14 ,  15 ,  16 ,  17 ] ) ; 
1042+         let  keys = Int8Array :: from_iter_values ( [ 2_i8 ,  3 ,  4 ] ) ; 
1043+ 
1044+         let  dict_array = DictionaryArray :: try_new ( & keys,  & values) . unwrap ( ) ; 
1045+         let  array = dict_array. downcast_dict :: < Int8Array > ( ) . unwrap ( ) ; 
1046+         assert_eq ! ( 39 ,  sum_dyn:: <Int8Type ,  _>( array) . unwrap( ) ) ; 
1047+ 
1048+         let  a = Int32Array :: from ( vec ! [ 1 ,  2 ,  3 ,  4 ,  5 ] ) ; 
1049+         assert_eq ! ( 15 ,  sum_dyn:: <Int32Type ,  _>( & a) . unwrap( ) ) ; 
1050+ 
1051+         let  keys = Int8Array :: from ( vec ! [ Some ( 2_i8 ) ,  None ,  Some ( 4 ) ] ) ; 
1052+         let  dict_array = DictionaryArray :: try_new ( & keys,  & values) . unwrap ( ) ; 
1053+         let  array = dict_array. downcast_dict :: < Int8Array > ( ) . unwrap ( ) ; 
1054+         assert_eq ! ( 26 ,  sum_dyn:: <Int8Type ,  _>( array) . unwrap( ) ) ; 
1055+ 
1056+         let  keys = Int8Array :: from ( vec ! [ None ,  None ,  None ] ) ; 
1057+         let  dict_array = DictionaryArray :: try_new ( & keys,  & values) . unwrap ( ) ; 
1058+         let  array = dict_array. downcast_dict :: < Int8Array > ( ) . unwrap ( ) ; 
1059+         assert ! ( sum_dyn:: <Int8Type ,  _>( array) . is_none( ) ) ; 
1060+     } 
10061061} 
0 commit comments