Merge pull request scikit-learn#5012 from yanlend/patch-2
[MRG+1] Initialize ARPACK eigsh
GaelVaroquaux committed Oct 23, 2015
2 parents 1863895 + 8ff339e, commit 2492dd0
Showing 7 changed files with 58 additions and 10 deletions.

doc/modules/pipeline.rst (2 changes: 1 addition & 1 deletion)

@@ -156,7 +156,7 @@ and ``value`` is an estimator object::
 n_components=None, whiten=False)), ('kernel_pca', KernelPCA(alpha=1.0,
 coef0=1, degree=3, eigen_solver='auto', fit_inverse_transform=False,
 gamma=None, kernel='linear', kernel_params=None, max_iter=None,
-n_components=None, remove_zero_eig=False, tol=0))],
+n_components=None, random_state=None, remove_zero_eig=False, tol=0))],
 transformer_weights=None)

 Like pipelines, feature unions have a shorthand constructor called

doc/whats_new.rst (4 changes: 4 additions & 0 deletions)

@@ -43,6 +43,10 @@ Bug fixes
 - Fixed bug in :func:`manifold.spectral_embedding` where diagonal of unnormalized
   Laplacian matrix was incorrectly set to 1. By `Peter Fischer`_.

+- Fixed incorrect initialization of :func:`utils.arpack.eigsh` on all
+  occurrences. Affects :class:`cluster.SpectralBiclustering`,
+  :class:`decomposition.KernelPCA`, :class:`manifold.LocallyLinearEmbedding`,
+  and :class:`manifold.SpectralEmbedding`. By `Peter Fischer`_.
+
 API changes summary
 -------------------
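
The underlying issue: numpy's rand() draws the ARPACK starting vector from [0, 1), so it has a large component along the all-ones vector, which for a graph Laplacian lies in or near the null space targeted by shift-invert mode; ARPACK's own default draws from [-1, 1]. A minimal standalone sketch of the corrected pattern, assuming only numpy and scipy (the test matrix and all names are illustrative)::

    import numpy as np
    from scipy.sparse.linalg import eigsh

    rng = np.random.RandomState(0)
    M = rng.rand(50, 50)
    M = np.dot(M.T, M)  # symmetric positive semi-definite test matrix

    # old, problematic choice: entries in [0, 1), strongly aligned
    # with the all-ones vector
    v0_bad = rng.rand(M.shape[0])

    # corrected choice: symmetric around zero, as ARPACK itself initializes
    v0 = rng.uniform(-1, 1, M.shape[0])

    w, V = eigsh(M, k=5, which='LA', v0=v0)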

sklearn/cluster/bicluster.py (15 changes: 11 additions & 4 deletions)

@@ -14,6 +14,7 @@
 from . import KMeans, MiniBatchKMeans
 from ..base import BaseEstimator, BiclusterMixin
 from ..externals import six
+from ..utils import check_random_state
 from ..utils.arpack import eigsh, svds

 from ..utils.extmath import (make_nonnegative, norm, randomized_svd,
@@ -140,12 +141,18 @@ def _svd(self, array, n_components, n_discard):
             # some eigenvalues of A * A.T are negative, causing
             # sqrt() to be np.nan. This causes some vectors in vt
             # to be np.nan.
-            _, v = eigsh(safe_sparse_dot(array.T, array),
-                         ncv=self.n_svd_vecs)
+            A = safe_sparse_dot(array.T, array)
+            random_state = check_random_state(self.random_state)
+            # initialize with [-1,1] as in ARPACK
+            v0 = random_state.uniform(-1, 1, A.shape[0])
+            _, v = eigsh(A, ncv=self.n_svd_vecs, v0=v0)
             vt = v.T
         if np.any(np.isnan(u)):
-            _, u = eigsh(safe_sparse_dot(array, array.T),
-                         ncv=self.n_svd_vecs)
+            A = safe_sparse_dot(array, array.T)
+            random_state = check_random_state(self.random_state)
+            # initialize with [-1,1] as in ARPACK
+            v0 = random_state.uniform(-1, 1, A.shape[0])
+            _, u = eigsh(A, ncv=self.n_svd_vecs, v0=v0)

         assert_all_finite(u)
         assert_all_finite(vt)
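
This fallback computes singular vectors of array as eigenvectors of its Gram matrix when svds yields NaNs; a standalone sketch of the same pattern (X, rng, and k are illustrative, not from the diff)::

    import numpy as np
    from scipy.sparse.linalg import eigsh

    rng = np.random.RandomState(42)
    X = rng.rand(30, 20)

    A = np.dot(X.T, X)  # Gram matrix; its eigenvectors are X's right singular vectors
    v0 = rng.uniform(-1, 1, A.shape[0])  # ARPACK-style start vector
    _, v = eigsh(A, k=3, v0=v0)
    vt = v.T  # plays the role of V^T from the SVD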

sklearn/decomposition/kernel_pca.py (15 changes: 13 additions & 2 deletions)

@@ -6,6 +6,7 @@
 import numpy as np
 from scipy import linalg

+from ..utils import check_random_state
 from ..utils.arpack import eigsh
 from ..utils.validation import check_is_fitted
 from ..exceptions import NotFittedError
@@ -76,6 +77,10 @@ class KernelPCA(BaseEstimator, TransformerMixin):
         When n_components is None, this parameter is ignored and components
         with zero eigenvalues are removed regardless.

+    random_state : int seed, RandomState instance, or None, default : None
+        A pseudo random number generator used for the initialization of the
+        residuals when eigen_solver == 'arpack'.
+
     Attributes
     ----------
@@ -103,7 +108,8 @@ class KernelPCA(BaseEstimator, TransformerMixin):
     def __init__(self, n_components=None, kernel="linear",
                  gamma=None, degree=3, coef0=1, kernel_params=None,
                  alpha=1.0, fit_inverse_transform=False, eigen_solver='auto',
-                 tol=0, max_iter=None, remove_zero_eig=False):
+                 tol=0, max_iter=None, remove_zero_eig=False,
+                 random_state=None):
         if fit_inverse_transform and kernel == 'precomputed':
             raise ValueError(
                 "Cannot fit_inverse_transform with a precomputed kernel.")
@@ -120,6 +126,7 @@ def __init__(self, n_components=None, kernel="linear",
         self.tol = tol
         self.max_iter = max_iter
         self._centerer = KernelCenterer()
+        self.random_state = random_state

     @property
     def _pairwise(self):
@@ -158,10 +165,14 @@ def _fit_transform(self, K):
             self.lambdas_, self.alphas_ = linalg.eigh(
                 K, eigvals=(K.shape[0] - n_components, K.shape[0] - 1))
         elif eigen_solver == 'arpack':
+            random_state = check_random_state(self.random_state)
+            # initialize with [-1,1] as in ARPACK
+            v0 = random_state.uniform(-1, 1, K.shape[0])
             self.lambdas_, self.alphas_ = eigsh(K, n_components,
                                                 which="LA",
                                                 tol=self.tol,
-                                                maxiter=self.max_iter)
+                                                maxiter=self.max_iter,
+                                                v0=v0)

         # sort eigenvectors in descending order
         indices = self.lambdas_.argsort()[::-1]
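
With the new random_state parameter, ARPACK-based fits become reproducible; a usage sketch (data, kernel, and parameter values are illustrative)::

    import numpy as np
    from sklearn.decomposition import KernelPCA

    X = np.random.RandomState(0).rand(100, 10)
    kpca = KernelPCA(n_components=2, kernel='rbf', gamma=0.5,
                     eigen_solver='arpack', random_state=0)
    X_kpca = kpca.fit_transform(X)  # identical across runs for a fixed seed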

sklearn/manifold/locally_linear.py (3 changes: 2 additions & 1 deletion)

@@ -148,7 +148,8 @@ def null_space(M, k, k_skip=1, eigen_solver='arpack', tol=1E-6, max_iter=100,

     if eigen_solver == 'arpack':
         random_state = check_random_state(random_state)
-        v0 = random_state.rand(M.shape[0])
+        # initialize with [-1,1] as in ARPACK
+        v0 = random_state.uniform(-1, 1, M.shape[0])
         try:
             eigen_values, eigen_vectors = eigsh(M, k + k_skip, sigma=0.0,
                                                 tol=tol, maxiter=max_iter,
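
null_space already accepted a random_state; only the sampling distribution for v0 changes. A usage sketch through the public estimator (parameter values are illustrative)::

    from sklearn.datasets import make_swiss_roll
    from sklearn.manifold import LocallyLinearEmbedding

    X, _ = make_swiss_roll(n_samples=500, random_state=0)
    lle = LocallyLinearEmbedding(n_neighbors=10, n_components=2,
                                 eigen_solver='arpack', random_state=0)
    X_lle = lle.fit_transform(X)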

sklearn/manifold/spectral_embedding_.py (3 changes: 2 additions & 1 deletion)

@@ -259,9 +259,10 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
             # We are computing the opposite of the laplacian inplace so as
             # to spare a memory allocation of a possibly very large array
             laplacian *= -1
+            v0 = random_state.uniform(-1, 1, laplacian.shape[0])
             lambdas, diffusion_map = eigsh(laplacian, k=n_components,
                                            sigma=1.0, which='LM',
-                                           tol=eigen_tol)
+                                           tol=eigen_tol, v0=v0)
             embedding = diffusion_map.T[n_components::-1] * dd
         except RuntimeError:
             # When submatrices are exactly singular, an LU decomposition
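
Here eigsh runs in shift-invert mode (sigma=1.0) on the negated Laplacian, which is exactly where a start vector aligned with the constant vector hurts most. A usage sketch through the public estimator (data and parameter values are illustrative)::

    import numpy as np
    from sklearn.manifold import SpectralEmbedding

    X = np.random.RandomState(0).rand(100, 5)
    se = SpectralEmbedding(n_components=2, eigen_solver='arpack',
                           random_state=0)
    X_se = se.fit_transform(X)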

sklearn/utils/tests/test_utils.py (26 changes: 25 additions & 1 deletion)

@@ -3,11 +3,13 @@
 import numpy as np
 import scipy.sparse as sp
 from scipy.linalg import pinv2
+from scipy.linalg import eigh
 from itertools import chain

 from sklearn.utils.testing import (assert_equal, assert_raises, assert_true,
                                    assert_almost_equal, assert_array_equal,
-                                   SkipTest, assert_raises_regex)
+                                   SkipTest, assert_raises_regex,
+                                   assert_greater_equal)

 from sklearn.utils import check_random_state
 from sklearn.utils import deprecated
@@ -18,7 +20,9 @@
 from sklearn.utils import shuffle
 from sklearn.utils import gen_even_slices
 from sklearn.utils.extmath import pinvh
+from sklearn.utils.arpack import eigsh
 from sklearn.utils.mocking import MockDataFrame
+from sklearn.utils.graph import graph_laplacian


 def test_make_rng():
@@ -126,6 +130,26 @@ def test_pinvh_simple_complex():
     assert_almost_equal(np.dot(a, a_pinv), np.eye(3))


+def test_arpack_eigsh_initialization():
+    # Non-regression test that shows null-space computation is better with
+    # initialization of eigsh from [-1,1] instead of [0,1]
+    random_state = check_random_state(42)
+
+    A = random_state.rand(50, 50)
+    A = np.dot(A.T, A)  # create s.p.d. matrix
+    A = graph_laplacian(A) + 1e-7 * np.identity(A.shape[0])
+    k = 5
+
+    # Test if eigsh is working correctly
+    # New initialization [-1,1] (as in original ARPACK)
+    # Was [0,1] before, with which this test could fail
+    v0 = random_state.uniform(-1, 1, A.shape[0])
+    w, _ = eigsh(A, k=k, sigma=0.0, v0=v0)
+
+    # Eigenvalues of s.p.d. matrix should be nonnegative, w[0] is smallest
+    assert_greater_equal(w[0], 0)
+
+
 def test_column_or_1d():
     EXAMPLES = [
         ("binary", ["spam", "egg", "spam"]),
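
The regression this test guards against can be reproduced with scipy alone; a sketch mirroring the test above (scipy's csgraph.laplacian stands in for sklearn's internal graph_laplacian; the seed and 1e-7 regularization follow the test)::

    import numpy as np
    from scipy.sparse.linalg import eigsh
    from scipy.sparse.csgraph import laplacian

    rng = np.random.RandomState(42)
    A = rng.rand(50, 50)
    A = np.dot(A.T, A)  # symmetric positive definite
    L = laplacian(A) + 1e-7 * np.identity(A.shape[0])

    v0 = rng.uniform(-1, 1, L.shape[0])  # start vector drawn from [-1, 1]
    w, _ = eigsh(L, k=5, sigma=0.0, v0=v0)
    assert w[0] >= 0  # smallest eigenvalue of an s.p.d. matrix is nonnegative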
