@@ -2,6 +2,8 @@ use ndarray::prelude::*;
22use ndarray:: Data ;
33use num_traits:: { Float , FromPrimitive } ;
44
5+ /// Extension trait for ArrayBase providing functions
6+ /// to compute different correlation measures.
57pub trait CorrelationExt < A , S >
68where
79 S : Data < Elem = A > ,
@@ -11,13 +13,13 @@ where
1113 ///
1214 /// Let `(r, o)` be the shape of `M`:
1315 /// - `r` is the number of random variables;
14- /// - `o` is the number of observations we have collected
16+ /// - `o` is the number of observations we have collected
1517 /// for each random variable.
16- ///
17- /// Every column in `M` is an experiment: a single observation for each
18+ ///
19+ /// Every column in `M` is an experiment: a single observation for each
1820 /// random variable.
1921 /// Each row in `M` contains all the observations for a certain random variable.
20- ///
22+ ///
2123 /// The parameter `ddof` specifies the "delta degrees of freedom". For
2224 /// example, to calculate the population covariance, use `ddof = 0`, or to
2325 /// calculate the sample covariance (unbiased estimate), use `ddof = 1`.
3739 /// x̅ = ― ∑ xᵢ
3840 /// n i=1
3941 /// ```
40- /// and similarly for ̅y.
42+ /// and similarly for y̅.
4143 ///
4244 /// **Panics** if `ddof` is greater than or equal to the number of
4345 /// observations, if the number of observations is zero and division by
@@ -56,11 +58,65 @@ where
5658 /// [2., 4., 6.]]);
5759 /// let covariance = a.cov(1.);
5860 /// assert_eq!(
59- /// covariance,
61+ /// covariance,
6062 /// aview2(&[[4., 4.], [4., 4.]])
6163 /// );
6264 /// ```
63- fn cov ( & self , ddof : A ) -> Array2 < A >
65+ fn cov ( & self , ddof : A ) -> Array2 < A >
66+ where
67+ A : Float + FromPrimitive ;
68+
69+ /// Return the [Pearson correlation coefficients](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
70+ /// for a 2-dimensional array of observations `M`.
71+ ///
72+ /// Let `(r, o)` be the shape of `M`:
73+ /// - `r` is the number of random variables;
74+ /// - `o` is the number of observations we have collected
75+ /// for each random variable.
76+ ///
77+ /// Every column in `M` is an experiment: a single observation for each
78+ /// random variable.
79+ /// Each row in `M` contains all the observations for a certain random variable.
80+ ///
81+ /// The Pearson correlation coefficient of two random variables is defined as:
82+ ///
83+ /// ```text
84+ /// cov(X, Y)
85+ /// rho(X, Y) = ――――――――――――
86+ /// std(X)std(Y)
87+ /// ```
88+ ///
89+ /// Let `R` be the matrix returned by this function. Then
90+ /// ```text
91+ /// R_ij = rho(X_i, X_j)
92+ /// ```
93+ ///
94+ /// **Panics** if `M` is empty, if the type cast of `n_observations`
95+ /// from `usize` to `A` fails or if the standard deviation of one of the random
96+ /// variables is zero and division by zero panics for type `A`.
97+ ///
98+ /// # Example
99+ ///
100+ /// ```
101+ /// extern crate ndarray;
102+ /// extern crate ndarray_stats;
103+ /// use ndarray::arr2;
104+ /// use ndarray_stats::CorrelationExt;
105+ ///
106+ /// let a = arr2(&[[1., 3., 5.],
107+ /// [2., 4., 6.]]);
108+ /// let corr = a.pearson_correlation();
109+ /// assert!(
110+ /// corr.all_close(
111+ /// &arr2(&[
112+ /// [1., 1.],
113+ /// [1., 1.],
114+ /// ]),
115+ /// 1e-7
116+ /// )
117+ /// );
118+ /// ```
119+ fn pearson_correlation ( & self ) -> Array2 < A >
64120 where
65121 A : Float + FromPrimitive ;
66122}
75131 {
76132 let observation_axis = Axis ( 1 ) ;
77133 let n_observations = A :: from_usize ( self . len_of ( observation_axis) ) . unwrap ( ) ;
78- let dof =
134+ let dof =
79135 if ddof >= n_observations {
80136 panic ! ( "`ddof` needs to be strictly smaller than the \
81137 number of observations provided for each \
@@ -88,16 +144,33 @@ where
88144 let covariance = denoised. dot ( & denoised. t ( ) ) ;
89145 covariance. mapv_into ( |x| x / dof)
90146 }
147+
148+ fn pearson_correlation ( & self ) -> Array2 < A >
149+ where
150+ A : Float + FromPrimitive ,
151+ {
152+ let observation_axis = Axis ( 1 ) ;
153+ // The ddof value doesn't matter, as long as we use the same one
154+ // for computing covariance and standard deviation
155+ // We choose -1 to avoid panicking when we only have one
156+ // observation per random variable (or no observations at all)
157+ let ddof = -A :: one ( ) ;
158+ let cov = self . cov ( ddof) ;
159+ let std = self . std_axis ( observation_axis, ddof) . insert_axis ( observation_axis) ;
160+ let std_matrix = std. dot ( & std. t ( ) ) ;
161+ // element-wise division
162+ cov / std_matrix
163+ }
91164}
92165
93166#[ cfg( test) ]
94- mod tests {
167+ mod cov_tests {
95168 use super :: * ;
96169 use rand;
97170 use rand:: distributions:: Range ;
98171 use ndarray_rand:: RandomExt ;
99172
100- quickcheck ! {
173+ quickcheck ! {
101174 fn constant_random_variables_have_zero_covariance_matrix( value: f64 ) -> bool {
102175 let n_random_variables = 3 ;
103176 let n_observations = 4 ;
@@ -112,21 +185,21 @@ mod tests {
112185 let n_random_variables = 3 ;
113186 let n_observations = 4 ;
114187 let a = Array :: random(
115- ( n_random_variables, n_observations) ,
188+ ( n_random_variables, n_observations) ,
116189 Range :: new( -bound. abs( ) , bound. abs( ) )
117190 ) ;
118191 let covariance = a. cov( 1. ) ;
119192 covariance. all_close( & covariance. t( ) , 1e-8 )
120193 }
121194 }
122-
195+
123196 #[ test]
124197 #[ should_panic]
125198 fn test_invalid_ddof ( ) {
126199 let n_random_variables = 3 ;
127200 let n_observations = 4 ;
128201 let a = Array :: random (
129- ( n_random_variables, n_observations) ,
202+ ( n_random_variables, n_observations) ,
130203 Range :: new ( 0. , 10. )
131204 ) ;
132205 let invalid_ddof = ( n_observations as f64 ) + rand:: random :: < f64 > ( ) . abs ( ) ;
@@ -200,4 +273,79 @@ mod tests {
200273 )
201274 ) ;
202275 }
203- }
276+ }
277+
/// Tests for `CorrelationExt::pearson_correlation`.
#[cfg(test)]
mod pearson_correlation_tests {
    use super::*;
    use rand::distributions::Range;
    use ndarray_rand::RandomExt;

    quickcheck! {
        // rho(X_i, X_j) == rho(X_j, X_i), so the result must match its transpose.
        fn output_matrix_is_symmetric(bound: f64) -> bool {
            let n_random_variables = 3;
            let n_observations = 4;
            let a = Array::random(
                (n_random_variables, n_observations),
                Range::new(-bound.abs(), bound.abs())
            );
            let pearson_correlation = a.pearson_correlation();
            pearson_correlation.all_close(&pearson_correlation.t(), 1e-8)
        }

        // Every row holds the same constant, so each entry of the result is
        // expected to be NaN.
        fn constant_random_variables_have_nan_correlation(value: f64) -> bool {
            let n_random_variables = 3;
            let n_observations = 4;
            let a = Array::from_elem((n_random_variables, n_observations), value);
            let pearson_correlation = a.pearson_correlation();
            pearson_correlation.iter().all(|x| x.is_nan())
        }
    }

    /// No random variables: the result is an empty 0x0 matrix.
    #[test]
    fn test_zero_variables() {
        let a = Array2::<f32>::zeros((0, 2));
        let pearson_correlation = a.pearson_correlation();
        assert_eq!(pearson_correlation.shape(), &[0, 0]);
    }

    /// Two random variables without observations: the call must not panic and
    /// every entry of the 2x2 result is expected to be NaN.
    #[test]
    fn test_zero_observations() {
        let a = Array2::<f32>::zeros((2, 0));
        let pearson = a.pearson_correlation();
        // The shape check keeps the `all` below from passing vacuously on an
        // empty result.
        assert_eq!(pearson.shape(), &[2, 2]);
        assert!(pearson.iter().all(|x| x.is_nan()));
    }

    /// Neither variables nor observations: the result is an empty 0x0 matrix.
    #[test]
    fn test_zero_variables_zero_observations() {
        let a = Array2::<f32>::zeros((0, 0));
        let pearson = a.pearson_correlation();
        assert_eq!(pearson.shape(), &[0, 0]);
    }

    /// Compares the output against reference coefficients for a fixed 5x4 input
    /// (the expected values are named after numpy's `corrcoef`).
    #[test]
    fn test_for_random_array() {
        let a = array![
            [0.16351516, 0.56863268, 0.16924196, 0.72579120],
            [0.44342453, 0.19834387, 0.25411802, 0.62462382],
            [0.97162731, 0.29958849, 0.17338142, 0.80198342],
            [0.91727132, 0.79817799, 0.62237124, 0.38970998],
            [0.26979716, 0.20887228, 0.95454999, 0.96290785]
        ];
        let numpy_corrcoeff = array![
            [1., 0.38089376, 0.08122504, -0.59931623, 0.1365648],
            [0.38089376, 1., 0.80918429, -0.52615195, 0.38954398],
            [0.08122504, 0.80918429, 1., 0.07134906, -0.17324776],
            [-0.59931623, -0.52615195, 0.07134906, 1., -0.8743213],
            [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.]
        ];
        assert_eq!(a.ndim(), 2);
        assert!(
            a.pearson_correlation().all_close(
                &numpy_corrcoeff,
                1e-7
            )
        );
    }
}
0 commit comments