From 6cb51b2a28281d1ebf2ac742605f1aeb569c2ac7 Mon Sep 17 00:00:00 2001 From: Peter Fischer Date: Tue, 21 Jul 2015 16:03:02 -0400 Subject: [PATCH 1/3] Initialize ARPACK eigsh `v0 = random_state.rand(M.shape[0])` leads to an initial residual vector in ARPACK which is all positive. However, this is not the absolute or squared residual, but a true difference. Thus, it is better to initialize with `v0=random_state.uniform(-1, 1, M.shape[0])` to have an equally distributed sign. This is the way that ARPACK initializes the residuals. The effect of the previous initialization is that eigsh frequently does not converge to the correct eigenvalues, e.g. negative eigenvalues for s.p.d. matrix, which leads to an incorrect null-space. - initialized all occurences of sklearn.utils.arpack.eigsh the same way it would be initialzed by ARPACK - regression test to test behavior of new initialization --- doc/whats_new.rst | 4 ++++ sklearn/cluster/bicluster.py | 15 +++++++++---- sklearn/decomposition/kernel_pca.py | 11 ++++++++-- sklearn/manifold/locally_linear.py | 3 ++- sklearn/manifold/spectral_embedding_.py | 3 ++- sklearn/utils/tests/test_utils.py | 28 ++++++++++++++++++++++++- 6 files changed, 55 insertions(+), 9 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index ad99711e50e3a..d45f1c54a0f10 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -43,6 +43,10 @@ Bug fixes - Fixed bug in :func:`manifold.spectral_embedding` where diagonal of unnormalized Laplacian matrix was incorrectly set to 1. By `Peter Fischer`_. + - Fixed incorrect initialization of :func:`utils.arpack.eigsh` on all + occurrences. Affects :class:`cluster.SpectralBiclustering`, + :class:`decomposition.KernelPCA`, :class:`manifold.LocallyLinearEmbedding`, + and :class:`manifold.SpectralEmbedding`. By `Peter Fischer`_. API changes summary ------------------- diff --git a/sklearn/cluster/bicluster.py b/sklearn/cluster/bicluster.py index 16a45c94715f4..682d6e4e200f0 100644 --- a/sklearn/cluster/bicluster.py +++ b/sklearn/cluster/bicluster.py @@ -14,6 +14,7 @@ from . import KMeans, MiniBatchKMeans from ..base import BaseEstimator, BiclusterMixin from ..externals import six +from ..utils import check_random_state from ..utils.arpack import eigsh, svds from ..utils.extmath import (make_nonnegative, norm, randomized_svd, @@ -140,12 +141,18 @@ def _svd(self, array, n_components, n_discard): # some eigenvalues of A * A.T are negative, causing # sqrt() to be np.nan. This causes some vectors in vt # to be np.nan. - _, v = eigsh(safe_sparse_dot(array.T, array), - ncv=self.n_svd_vecs) + A = safe_sparse_dot(array.T, array) + random_state = check_random_state(self.random_state) + # initialize with [-1,1] as in ARPACK + v0 = random_state.uniform(-1, 1, A.shape[0]) + _, v = eigsh(A, ncv=self.n_svd_vecs, v0=v0) vt = v.T if np.any(np.isnan(u)): - _, u = eigsh(safe_sparse_dot(array, array.T), - ncv=self.n_svd_vecs) + A = safe_sparse_dot(array, array.T) + random_state = check_random_state(self.random_state) + # initialize with [-1,1] as in ARPACK + v0 = random_state.uniform(-1, 1, A.shape[0]) + _, u = eigsh(A, ncv=self.n_svd_vecs, v0=v0) assert_all_finite(u) assert_all_finite(vt) diff --git a/sklearn/decomposition/kernel_pca.py b/sklearn/decomposition/kernel_pca.py index 4131bf2fc642e..c7d657ea41475 100644 --- a/sklearn/decomposition/kernel_pca.py +++ b/sklearn/decomposition/kernel_pca.py @@ -6,6 +6,7 @@ import numpy as np from scipy import linalg +from ..utils import check_random_state from ..utils.arpack import eigsh from ..utils.validation import check_is_fitted from ..exceptions import NotFittedError @@ -103,7 +104,8 @@ class KernelPCA(BaseEstimator, TransformerMixin): def __init__(self, n_components=None, kernel="linear", gamma=None, degree=3, coef0=1, kernel_params=None, alpha=1.0, fit_inverse_transform=False, eigen_solver='auto', - tol=0, max_iter=None, remove_zero_eig=False): + tol=0, max_iter=None, remove_zero_eig=False, + random_state=None): if fit_inverse_transform and kernel == 'precomputed': raise ValueError( "Cannot fit_inverse_transform with a precomputed kernel.") @@ -120,6 +122,7 @@ def __init__(self, n_components=None, kernel="linear", self.tol = tol self.max_iter = max_iter self._centerer = KernelCenterer() + self.random_state = random_state @property def _pairwise(self): @@ -158,10 +161,14 @@ def _fit_transform(self, K): self.lambdas_, self.alphas_ = linalg.eigh( K, eigvals=(K.shape[0] - n_components, K.shape[0] - 1)) elif eigen_solver == 'arpack': + random_state = check_random_state(self.random_state) + # initialize with [-1,1] as in ARPACK + v0 = random_state.uniform(-1, 1, K.shape[0]) self.lambdas_, self.alphas_ = eigsh(K, n_components, which="LA", tol=self.tol, - maxiter=self.max_iter) + maxiter=self.max_iter, + v0=v0) # sort eigenvectors in descending order indices = self.lambdas_.argsort()[::-1] diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py index 9eed4432405cb..9ddc7008d2e27 100644 --- a/sklearn/manifold/locally_linear.py +++ b/sklearn/manifold/locally_linear.py @@ -148,7 +148,8 @@ def null_space(M, k, k_skip=1, eigen_solver='arpack', tol=1E-6, max_iter=100, if eigen_solver == 'arpack': random_state = check_random_state(random_state) - v0 = random_state.rand(M.shape[0]) + # initialize with [-1,1] as in ARPACK + v0 = random_state.uniform(-1, 1, M.shape[0]) try: eigen_values, eigen_vectors = eigsh(M, k + k_skip, sigma=0.0, tol=tol, maxiter=max_iter, diff --git a/sklearn/manifold/spectral_embedding_.py b/sklearn/manifold/spectral_embedding_.py index e4166d66f2e66..0d011d4c54592 100644 --- a/sklearn/manifold/spectral_embedding_.py +++ b/sklearn/manifold/spectral_embedding_.py @@ -259,9 +259,10 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None, # We are computing the opposite of the laplacian inplace so as # to spare a memory allocation of a possibly very large array laplacian *= -1 + v0 = random_state.uniform(-1, 1, laplacian.shape[0]) lambdas, diffusion_map = eigsh(laplacian, k=n_components, sigma=1.0, which='LM', - tol=eigen_tol) + tol=eigen_tol, v0=v0) embedding = diffusion_map.T[n_components::-1] * dd except RuntimeError: # When submatrices are exactly singular, an LU decomposition diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 52b05dc19ec0c..a28809c787033 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -3,11 +3,13 @@ import numpy as np import scipy.sparse as sp from scipy.linalg import pinv2 +from scipy.linalg import eigh from itertools import chain from sklearn.utils.testing import (assert_equal, assert_raises, assert_true, assert_almost_equal, assert_array_equal, - SkipTest, assert_raises_regex) + SkipTest, assert_raises_regex, + assert_greater_equal) from sklearn.utils import check_random_state from sklearn.utils import deprecated @@ -18,7 +20,9 @@ from sklearn.utils import shuffle from sklearn.utils import gen_even_slices from sklearn.utils.extmath import pinvh +from sklearn.utils.arpack import eigsh from sklearn.utils.mocking import MockDataFrame +from sklearn.utils.graph import graph_laplacian def test_make_rng(): @@ -126,6 +130,28 @@ def test_pinvh_simple_complex(): assert_almost_equal(np.dot(a, a_pinv), np.eye(3)) +def test_arpack_eigsh_initialization(): + ''' + Non-regression test that shows null-space computation is better with + initialization of eigsh from [-1,1] instead of [0,1] + ''' + random_state = check_random_state(42) + + A = random_state.rand(50, 50) + A = np.dot(A.T, A) # create s.p.d. matrix + A = graph_laplacian(A) + k = 5 + + # Test if eigsh is working correctly + # New initialization [-1,1] (as in original ARPACK) + # Was [0,1] before, with which this test could fail + v0 = random_state.uniform(-1,1, A.shape[0]) + w, _ = eigsh(A, k=k, sigma=0.0, v0=v0) + + # Eigenvalues of s.p.d. matrix should be nonnegative, w[0] is smallest + assert_greater_equal(w[0], 0) + + def test_column_or_1d(): EXAMPLES = [ ("binary", ["spam", "egg", "spam"]), From 9f8e062571addccd6eeef8277a28d36ce1649cb2 Mon Sep 17 00:00:00 2001 From: yanlend Date: Fri, 24 Jul 2015 11:48:35 -0400 Subject: [PATCH 2/3] fix docs and doctests for KernelPCA --- doc/modules/pipeline.rst | 2 +- sklearn/decomposition/kernel_pca.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/modules/pipeline.rst b/doc/modules/pipeline.rst index b4709e7008cc1..0f9c1f635995d 100644 --- a/doc/modules/pipeline.rst +++ b/doc/modules/pipeline.rst @@ -156,7 +156,7 @@ and ``value`` is an estimator object:: n_components=None, whiten=False)), ('kernel_pca', KernelPCA(alpha=1.0, coef0=1, degree=3, eigen_solver='auto', fit_inverse_transform=False, gamma=None, kernel='linear', kernel_params=None, max_iter=None, - n_components=None, remove_zero_eig=False, tol=0))], + n_components=None, random_state=None, remove_zero_eig=False, tol=0))], transformer_weights=None) Like pipelines, feature unions have a shorthand constructor called diff --git a/sklearn/decomposition/kernel_pca.py b/sklearn/decomposition/kernel_pca.py index c7d657ea41475..dc83e10ec527b 100644 --- a/sklearn/decomposition/kernel_pca.py +++ b/sklearn/decomposition/kernel_pca.py @@ -77,6 +77,10 @@ class KernelPCA(BaseEstimator, TransformerMixin): When n_components is None, this parameter is ignored and components with zero eigenvalues are removed regardless. + random_state : int seed, RandomState instance, or None, default : None + A pseudo random number generator used for the initialization of the + residuals when eigen_solver == 'arpack'. + Attributes ---------- From 8ff339eb5ba4a1353cd228f577e59da6865c066a Mon Sep 17 00:00:00 2001 From: Peter Fischer Date: Fri, 31 Jul 2015 08:27:28 -0400 Subject: [PATCH 3/3] Update test_utils.py Comment instead of docstring to follow sklearn convention. Regression test made easier by adding small value to main diagonal of s.p.d. matrix --- sklearn/utils/tests/test_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index a28809c787033..54c18aec27daf 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -131,15 +131,13 @@ def test_pinvh_simple_complex(): def test_arpack_eigsh_initialization(): - ''' - Non-regression test that shows null-space computation is better with - initialization of eigsh from [-1,1] instead of [0,1] - ''' + # Non-regression test that shows null-space computation is better with + # initialization of eigsh from [-1,1] instead of [0,1] random_state = check_random_state(42) A = random_state.rand(50, 50) A = np.dot(A.T, A) # create s.p.d. matrix - A = graph_laplacian(A) + A = graph_laplacian(A) + 1e-7 * np.identity(A.shape[0]) k = 5 # Test if eigsh is working correctly