@@ -677,33 +677,31 @@ def mean(self, axis, **kwargs):
677677
678678 skipna = kwargs .get ("skipna" , True )
679679
680- def map_apply_fn (ser , ** kwargs ):
681- try :
682- sum_result = ser .sum (skipna = skipna )
683- count_result = ser .count ()
684- except TypeError :
685- return None
686- else :
687- return (sum_result , count_result )
688-
689- def reduce_apply_fn (ser , ** kwargs ):
690- sum_result = ser .apply (lambda x : x [0 ]).sum (skipna = skipna )
691- count_result = ser .apply (lambda x : x [1 ]).sum (skipna = skipna )
692- return sum_result / count_result
680+ # TODO-FIX: this function may work incorrectly with user-defined "numeric" values.
681+ # Since `count(numeric_only=True)` discards all unknown "numeric" types, we can get incorrect
682+ # divisor inside the reduce function.
683+ def map_fn (df , ** kwargs ):
684+ result = pandas .DataFrame (
685+ {
686+ "sum" : df .sum (axis = axis , skipna = skipna ),
687+ "count" : df .count (axis = axis , numeric_only = True ),
688+ }
689+ )
690+ return result if axis else result .T
693691
694692 def reduce_fn (df , ** kwargs ):
695- df .dropna (axis = 1 , inplace = True , how = "any" )
696- return build_applyier (reduce_apply_fn , axis = axis )(df )
697-
698- def build_applyier (func , ** applyier_kwargs ):
699- def applyier (df , ** kwargs ):
700- result = df .apply (func , ** applyier_kwargs )
701- return result .set_axis (df .axes [axis ^ 1 ], axis = 0 )
693+ sum_cols = df ["sum" ] if axis else df .loc ["sum" ]
694+ count_cols = df ["count" ] if axis else df .loc ["count" ]
702695
703- return applyier
696+ if not isinstance (sum_cols , pandas .Series ):
697+ # If we got `NaN` as the result of the sum in any axis partition,
698+ # then we must consider the whole sum as `NaN`, so setting `skipna=False`
699+ sum_cols = sum_cols .sum (axis = axis , skipna = False )
700+ count_cols = count_cols .sum (axis = axis , skipna = False )
701+ return sum_cols / count_cols
704702
705703 return MapReduceFunction .register (
706- build_applyier ( map_apply_fn , axis = axis , result_type = "reduce" ) ,
704+ map_fn ,
707705 reduce_fn ,
708706 preserve_index = (kwargs .get ("numeric_only" ) is not None ),
709707 )(self , axis = axis , ** kwargs )
0 commit comments