@@ -33,37 +33,45 @@ def _(X: np.ndarray, *, axis: Literal[0, 1], dtype: DTypeLike) -> np.ndarray:
 def _get_mean_var(
     X: _SupportedArray, *, axis: Literal[0, 1] = 0
 ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
-    if isinstance(X, sparse.spmatrix):
-        mean, var = sparse_mean_variance_axis(X, axis=axis)
-        var *= X.shape[axis] / (X.shape[axis] - 1)
+    if isinstance(X, np.ndarray):
+        mean, var = _compute_mean_var(X, axis=axis, dtype=np.float64)
     else:
-        mean,var = _compute_mean_var(X,axis=axis,dtype=np.float64)
+        if isinstance(X, sparse.spmatrix):
+            mean, var = sparse_mean_variance_axis(X, axis=axis)
+        else:
+            mean = axis_mean(X, axis=axis, dtype=np.float64)
+            mean_sq = axis_mean(elem_mul(X, X), axis=axis, dtype=np.float64)
+            var = mean_sq - mean**2
+        # enforce R convention (unbiased estimator) for variance
+        var *= X.shape[axis] / (X.shape[axis] - 1)
     return mean, var
 
-@numba.njit(cache=True,parallel=True)
+
+@numba.njit(cache=True, parallel=True)
 def _compute_mean_var(
-    X: _SupportedArray, axis: Literal[0, 1] = 0,dtype: DTypeLike | None = None
+    X: _SupportedArray, axis: Literal[0, 1] = 0, dtype: DTypeLike | None = None
 ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
     nthr = numba.get_num_threads()
-    axis_i = 1 if axis == 0 else 0
-    s = np.zeros((nthr,X.shape[axis_i]),dtype=dtype)
-    ss = np.zeros((nthr,X.shape[axis_i]),dtype=dtype)
-    mean = np.zeros(X.shape[axis_i],dtype=dtype)
-    #std=np.zeros(X.shape[axis_i],dtype=dtype)
-    var = np.zeros(X.shape[axis_i],dtype=dtype)
+    axis_i = 1 if axis == 0 else 0
+    s = np.zeros((nthr, X.shape[axis_i]), dtype=dtype)
+    ss = np.zeros((nthr, X.shape[axis_i]), dtype=dtype)
+    mean = np.zeros(X.shape[axis_i], dtype=dtype)
+    # std=np.zeros(X.shape[axis_i],dtype=dtype)
+    var = np.zeros(X.shape[axis_i], dtype=dtype)
     n = X.shape[axis]
     for i in numba.prange(nthr):
-        for r in range(i,n, nthr):
+        for r in range(i, n, nthr):
             for c in range(X.shape[axis_i]):
-                v = X[r,c] if axis == 0 else X[c,r]
-                s[i,c] += v
-                ss[i,c] += v * v
+                v = X[r, c] if axis == 0 else X[c, r]
+                s[i, c] += v
+                ss[i, c] += v * v
     for c in numba.prange(X.shape[axis_i]):
-        s0 = s[:,c].sum()
-        mean[c] = s0 / n
-        var[c] = (ss[:,c].sum() - s0 * s0 / n)/ (n - 1)
-        #std[c]=np.sqrt(var[c])
-    return mean,var
+        s0 = s[:, c].sum()
+        mean[c] = s0 / n
+        var[c] = (ss[:, c].sum() - s0 * s0 / n) / (n - 1)
+        # std[c]=np.sqrt(var[c])
+    return mean, var
+
 
 def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int):
     """