Skip to content

Commit 97717ba

Browse files
veneamueller
authored andcommitted
Use pinvh wherever it helps in the codebase.
Use pinvh in plot_sparse_recovery example Use pinvh in bayes.py Use pinvh in GMM and DPGMM
1 parent 6d66d0f commit 97717ba

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

examples/linear_model/plot_sparse_recovery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,15 @@
5353
from sklearn.preprocessing import Scaler
5454
from sklearn.metrics import auc, precision_recall_curve
5555
from sklearn.ensemble import ExtraTreesRegressor
56+
from sklearn.utils.extmath import pinvh
5657

5758

5859
def mutual_incoherence(X_relevant, X_irelevant):
5960
"""Mutual incoherence, as defined by formula (26a) of [Wainwright2006].
6061
"""
6162
projector = np.dot(
6263
np.dot(X_irelevant.T, X_relevant),
63-
linalg.pinv(np.dot(X_relevant.T, X_relevant))
64+
pinvh(np.dot(X_relevant.T, X_relevant))
6465
)
6566
return np.max(np.abs(projector).sum(axis=1))
6667

sklearn/linear_model/bayes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from .base import LinearModel
1313
from ..base import RegressorMixin
14-
from ..utils.extmath import fast_logdet
14+
from ..utils.extmath import fast_logdet, pinvh
1515
from ..utils import check_arrays
1616

1717

@@ -382,7 +382,7 @@ def fit(self, X, y):
382382
### Iterative procedure of ARDRegression
383383
for iter_ in range(self.n_iter):
384384
### Compute mu and sigma (using Woodbury matrix identity)
385-
sigma_ = linalg.pinv(np.eye(n_samples) / alpha_ +
385+
sigma_ = pinvh(np.eye(n_samples) / alpha_ +
386386
np.dot(X[:, keep_lambda] *
387387
np.reshape(1. / lambda_[keep_lambda], [1, -1]),
388388
X[:, keep_lambda].T))

sklearn/mixture/dpgmm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from scipy.spatial.distance import cdist
1717

1818
from ..utils import check_random_state
19-
from ..utils.extmath import norm, logsumexp
19+
from ..utils.extmath import norm, logsumexp, pinvh
2020
from .. import cluster
2121
from .gmm import GMM
2222

@@ -215,7 +215,7 @@ def _get_precisions(self):
215215
return [self.precs_] * self.n_components
216216

217217
def _get_covars(self):
218-
return [linalg.pinv(c) for c in self._get_precisions()]
218+
return [pinvh(c) for c in self._get_precisions()]
219219

220220
def _set_covars(self, covars):
221221
raise NotImplementedError("""The variational algorithm does
@@ -332,7 +332,7 @@ def _update_precisions(self, X, z):
332332
for k in xrange(self.n_components):
333333
diff = X - self.means_[k]
334334
self.scale_ += np.dot(diff.T, z[:, k:k + 1] * diff)
335-
self.scale_ = linalg.pinv(self.scale_)
335+
self.scale_ = pinvh(self.scale_)
336336
self.precs_ = self.dof_ * self.scale_
337337
self.det_scale_ = linalg.det(self.scale_)
338338
self.bound_prec_ = 0.5 * wishart_log_det(
@@ -346,7 +346,7 @@ def _update_precisions(self, X, z):
346346
self.scale_[k] = (sum_resp + 1) * np.identity(n_features)
347347
diff = X - self.means_[k]
348348
self.scale_[k] += np.dot(diff.T, z[:, k:k + 1] * diff)
349-
self.scale_[k] = linalg.pinv(self.scale_[k])
349+
self.scale_[k] = pinvh(self.scale_[k])
350350
self.precs_[k] = self.dof_[k] * self.scale_[k]
351351
self.det_scale_[k] = linalg.det(self.scale_[k])
352352
self.bound_prec_[k] = 0.5 * wishart_log_det(self.dof_[k],

sklearn/mixture/gmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ..base import BaseEstimator
1616
from ..utils import check_random_state, deprecated
17-
from ..utils.extmath import logsumexp
17+
from ..utils.extmath import logsumexp, pinvh
1818
from .. import cluster
1919

2020
EPS = np.finfo(float).eps
@@ -616,7 +616,7 @@ def _log_multivariate_normal_density_tied(X, means, covars):
616616
"""Compute Gaussian log-density at X for a tied model"""
617617
from scipy import linalg
618618
n_samples, n_dim = X.shape
619-
icv = linalg.pinv(covars)
619+
icv = pinvh(covars)
620620
lpr = -0.5 * (n_dim * np.log(2 * np.pi) + np.log(linalg.det(covars) + 0.1)
621621
+ np.sum(X * np.dot(X, icv), 1)[:, np.newaxis]
622622
- 2 * np.dot(np.dot(X, icv), means.T)

0 commit comments

Comments
 (0)