Merged
Commits
28 commits
e79de49
Add shape checking utilities for coefficients of precomputed kernel r…
bhelfrecht May 4, 2021
2c9f4dd
Modify instantiation and fit call of KPCovR to accept pre-fitted regr…
bhelfrecht May 4, 2021
456166b
Update KPCovR tests to be compatible with new regressor usage
bhelfrecht May 4, 2021
a458d37
Update PCovR example notebook to be compatible with new regressor usage
bhelfrecht May 4, 2021
70188b6
Remove alpha from PCovR docstring
bhelfrecht May 4, 2021
65b9296
Reorganize regressor usage to pull kernel parameters directly from th…
bhelfrecht May 5, 2021
41aaff6
Pull alpha from the KPCovR regressor
bhelfrecht May 5, 2021
bfc04fa
Update KPCovR documentation
bhelfrecht May 5, 2021
a57bf7d
Update utilities documentation
bhelfrecht May 5, 2021
672fcf6
Make regressor default argument None, assign default within __init__
bhelfrecht May 5, 2021
3536cf5
Change inversions to use least squares with singular value cutoff bas…
bhelfrecht May 6, 2021
e96ccfc
Compute Yhat directly from the dual coefficients
bhelfrecht May 6, 2021
46a56e7
Move regressor checking to occur immediately
bhelfrecht May 6, 2021
36d2794
Add more details about pre-fitted regressors to PCovR and KPCovR docu…
bhelfrecht May 6, 2021
e8b13c5
Fix docstring of _check_dual_coefs
bhelfrecht May 6, 2021
9fa4988
Use KPCovR tolerance in matrix inversion instead of regularization
bhelfrecht May 6, 2021
bc69f4f
Add tests for KPCovR to cover the pre-fitted regressors
bhelfrecht May 6, 2021
585e9e5
Add PCovR test to check for regressor modifications
bhelfrecht May 6, 2021
a368bf5
Move default regressor assignment to fit and accept regressor params
bhelfrecht May 10, 2021
1308835
Reorganize KPCovR regressor infrastructure
bhelfrecht May 10, 2021
c53fee2
Make PCovR example compatible with new KPCovR regressor infrastructure
bhelfrecht May 10, 2021
8e3c12c
Add PCovR test for None regressor
bhelfrecht May 10, 2021
5cd66bb
Modify KPCovR tests for compatibility with new regressor infrastructure
bhelfrecht May 10, 2021
a27faa1
Add KPCovR test for None regressor
bhelfrecht May 10, 2021
95e0bab
Fix KPCovR docstring example
bhelfrecht May 12, 2021
8da0385
Consolidate regressor checking
bhelfrecht May 12, 2021
5aac7d1
Simplify tests for pre-fitted regressors
bhelfrecht May 12, 2021
350c15b
Negate KPCovR score according to sklearn guidelines
bhelfrecht May 12, 2021
7 changes: 6 additions & 1 deletion examples/PCovR.ipynb
@@ -25,6 +25,7 @@
"from skcosmo.decomposition import PCovR\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.linear_model import Ridge\n",
"from sklearn.kernel_ridge import KernelRidge\n",
"\n",
"cmapX = cm.plasma\n",
"cmapy = cm.Greys"
@@ -182,7 +183,11 @@
"mixing = 0.5\n",
"kpcovr = KernelPCovR(\n",
" mixing=mixing,\n",
" alpha=1e-8,\n",
" regressor=KernelRidge(\n",
" alpha=1e-8,\n",
" kernel=\"rbf\",\n",
" gamma=0.1,\n",
" ),\n",
" kernel=\"rbf\",\n",
" gamma=0.1,\n",
" n_components=2,\n",
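For readers following along, here is a self-contained sketch of the updated notebook call; the random `X` and `y` below are illustrative stand-ins, not the notebook's dataset:

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge
from skcosmo.decomposition import KernelPCovR

# Illustrative stand-in data; the notebook uses its own dataset
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
y = rng.normal(size=(50,))

# alpha now lives on the regressor; the kernel parameters given to
# KernelRidge must match those given to KernelPCovR itself
kpcovr = KernelPCovR(
    mixing=0.5,
    regressor=KernelRidge(alpha=1e-8, kernel="rbf", gamma=0.1),
    kernel="rbf",
    gamma=0.1,
    n_components=2,
)
kpcovr.fit(X, y)
T = kpcovr.transform(X)
```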
156 changes: 115 additions & 41 deletions skcosmo/decomposition/_kernel_pcovr.py
@@ -5,6 +5,8 @@
from scipy.sparse.linalg import svds
from sklearn.decomposition._base import _BasePCA
from sklearn.decomposition._pca import _infer_dimension
from sklearn.exceptions import NotFittedError
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model._base import LinearModel
from sklearn.metrics.pairwise import pairwise_kernels
from sklearn.utils import (
@@ -23,7 +25,10 @@
)

from ..preprocessing import KernelNormalizer
from ..utils import pcovr_kernel
from ..utils import (
check_krr_fit,
pcovr_kernel,
)


class KernelPCovR(_BasePCA, LinearModel):
@@ -75,10 +80,18 @@ class KernelPCovR(_BasePCA, LinearModel):
If randomized :
run randomized SVD by the method of Halko et al.

regressor : instance of `sklearn.kernel_ridge.KernelRidge`, default=None
The regressor to use for computing
the property predictions :math:`\\hat{\\mathbf{Y}}`.
A pre-fitted regressor may be provided.
If the regressor is not `None`, its kernel parameters
(`kernel`, `gamma`, `degree`, `coef0`, and `kernel_params`)
must be identical to those passed directly to `KernelPCovR`.

kernel: "linear" | "poly" | "rbf" | "sigmoid" | "cosine" | "precomputed"
Kernel. Default="linear".

gamma: float, default=1/n_features
gamma: float, default=None
Kernel coefficient for rbf, poly and sigmoid kernels. Ignored by other
kernels.

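A sketch of the constraint documented in the `regressor` entry above: mismatched kernel settings raise a `ValueError` at fit time (the data and parameter values here are illustrative only):

```python
import numpy as np
from sklearn.kernel_ridge import KernelRidge
from skcosmo.decomposition import KernelPCovR

X = np.random.rand(10, 3)
Y = np.random.rand(10)

# gamma differs between the regressor (0.5) and KernelPCovR (0.1),
# so fit() raises the kernel-parameter-mismatch ValueError
kpcovr = KernelPCovR(
    mixing=0.5,
    regressor=KernelRidge(kernel="rbf", gamma=0.5),
    kernel="rbf",
    gamma=0.1,
)
try:
    kpcovr.fit(X, Y)
except ValueError as err:
    print(err)
```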
@@ -96,15 +109,13 @@ class KernelPCovR(_BasePCA, LinearModel):
center: bool, default=False
Whether to center any computed kernels

alpha: float, default=1E-6
Regularization parameter to use in all regression operations.

fit_inverse_transform: bool, default=False
Learn the inverse transform for non-precomputed kernels.
(i.e. learn to find the pre-image of a point)

tol: float, default=1e-12
Tolerance for singular values computed by svd_solver == 'arpack'.
Tolerance for singular values computed by svd_solver == 'arpack'
and for matrix inversions.
Must be of range [0.0, infinity).

n_jobs: int, default=None
@@ -121,6 +132,9 @@
Used when the 'arpack' or 'randomized' solvers are used. Pass an int
for reproducible results across multiple function calls.

**regressor_params: additional keyword arguments to be passed
to the regressor. Ignored if `regressor` is not `None`.

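A minimal illustration of `**regressor_params` as documented above: with no `regressor` given, extra keyword arguments are forwarded to the internally constructed `KernelRidge` (values are illustrative):

```python
from skcosmo.decomposition import KernelPCovR

# No regressor given: KernelPCovR builds its own KernelRidge and
# forwards alpha (a KernelRidge parameter) via **regressor_params
kpcovr = KernelPCovR(mixing=0.5, kernel="rbf", gamma=0.1, alpha=1e-8)
```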

Attributes
----------
@@ -154,53 +168,53 @@ class KernelPCovR(_BasePCA, LinearModel):
>>> import numpy as np
>>> from skcosmo.decomposition import KernelPCovR
>>> from skcosmo.preprocessing import StandardFlexibleScaler as SFS
>>> from sklearn.kernel_ridge import KernelRidge
>>>
>>> X = np.array([[-1, 1, -3, 1], [1, -2, 1, 2], [-2, 0, -2, -2], [1, 0, 2, -1]])
>>> X = SFS().fit_transform(X)
>>> Y = np.array([[ 0, -5], [-1, 1], [1, -5], [-3, 2]])
>>> Y = SFS(column_wise=True).fit_transform(Y)
>>>
>>> kpcovr = KernelPCovR(mixing=0.1, n_components=2, kernel='rbf', gamma=2)
>>> kpcovr = KernelPCovR(mixing=0.1, n_components=2, regressor=KernelRidge(kernel='rbf', gamma=1), kernel='rbf', gamma=1)
>>> kpcovr.fit(X, Y)
KernelPCovR(coef0=1, degree=3, fit_inverse_transform=False, gamma=0.01, kernel='rbf',
kernel_params=None, mixing=None, n_components=2, n_jobs=None,
alpha=None, tol=1e-12)
KernelPCovR(gamma=1, kernel='rbf', mixing=0.1, n_components=2,
regressor=KernelRidge(gamma=1, kernel='rbf'))
>>> T = kpcovr.transform(X)
[[ 1.01199065, -0.35439061],
[-0.68099591, 0.48912275],
[ 1.4677616 , 0.13757037],
[-1.79874193, -0.27232032]]
[[-0.61261285, -0.18937908],
[ 0.45242098, 0.25453465],
[-0.77871824, 0.04847559],
[ 0.91186937, -0.21211816]]
>>> Yp = kpcovr.predict(X)
[[-0.01044648, -0.84443158],
[-0.1758848 , 0.16224503],
[ 0.1573037 , -0.84211944],
[-0.51133139, 0.32552881]]
[[ 0.5100212 , -0.99488463],
[-0.18992219, 0.82064368],
[ 1.11923584, -1.04798016],
[-1.5635827 , 1.11078662]]
>>> kpcovr.score(X, Y)
(0.5312320029915978, 0.06254540655698511)
-0.520388347837897
"""

def __init__(
self,
mixing=0.5,
n_components=None,
svd_solver="auto",
regressor=None,
kernel="linear",
gamma=None,
degree=3,
coef0=1,
alpha=1e-6,
kernel_params=None,
center=False,
fit_inverse_transform=False,
tol=1e-12,
n_jobs=None,
iterated_power="auto",
random_state=None,
**regressor_params
):

self.mixing = mixing
self.n_components = n_components
self.alpha = alpha

self.svd_solver = svd_solver
self.tol = tol
@@ -209,15 +223,19 @@ def __init__(
self.center = center

self.kernel = kernel
self.kernel_params = kernel_params
self.gamma = gamma
self.degree = degree
self.coef0 = coef0
self.kernel_params = kernel_params

self.n_jobs = n_jobs
self.n_samples_ = None

self.fit_inverse_transform = fit_inverse_transform

self.regressor = regressor
self.regressor_params = regressor_params

def _get_kernel(self, X, Y=None):
if callable(self.kernel):
params = self.kernel_params or {}
@@ -252,9 +270,9 @@ def _fit(self, K, Yhat, W):
self.pkt_ = P @ U @ np.sqrt(np.diagflat(S_inv))

T = K @ self.pkt_
self.pt__ = np.linalg.lstsq(T, np.eye(T.shape[0]), rcond=self.alpha)[0]
self.pt__ = np.linalg.lstsq(T, np.eye(T.shape[0]), rcond=self.tol)[0]

def fit(self, X, Y, Yhat=None, W=None):
def fit(self, X, Y):
"""

Fit the model with X and Y.
@@ -279,18 +297,16 @@ def fit(self, X, Y, Yhat=None, W=None):
to have unit variance, otherwise :math:`\\mathbf{Y}` should be
scaled so that each feature has a variance of 1 / n_features.

Yhat: ndarray, shape (n_samples, n_properties), optional
Regressed training data, where n_samples is the number of samples and
n_properties is the number of properties. If not supplied, computed
by ridge regression.

Returns
-------
self: object
Returns the instance itself.

"""

if self.regressor is not None and not isinstance(self.regressor, KernelRidge):
raise ValueError("Regressor must be an instance of `KernelRidge`")

X, Y = check_X_y(X, Y, y_numeric=True, multi_output=True)
self.X_fit_ = X.copy()

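Since `Yhat` and `W` are no longer `fit` arguments (see the removed docstring entry above), the equivalent workflow is to pre-fit the regressor and pass it in. A sketch, assuming `X` and `Y` are defined as in the earlier examples and the kernel settings match:

```python
from sklearn.kernel_ridge import KernelRidge
from skcosmo.decomposition import KernelPCovR

# Pre-fit the regressor; its dual coefficients supply Yhat inside fit
krr = KernelRidge(kernel="rbf", gamma=0.1).fit(X, Y)

kpcovr = KernelPCovR(mixing=0.5, regressor=krr, kernel="rbf", gamma=0.1)
kpcovr.fit(X, Y)  # reuses krr rather than refitting it
```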
@@ -308,14 +324,66 @@

self.n_samples_ = X.shape[0]

if W is None:
if Yhat is None:
W = (np.linalg.lstsq(K, Y, rcond=self.alpha)[0]).reshape(X.shape[0], -1)
else:
W = np.linalg.lstsq(K, Yhat, rcond=self.alpha)[0]
if self.regressor is None:
regressor = KernelRidge(
kernel=self.kernel,
gamma=self.gamma,
degree=self.degree,
coef0=self.coef0,
kernel_params=self.kernel_params,
**self.regressor_params,
)
else:
regressor = self.regressor
kernel_attrs = ["kernel", "gamma", "degree", "coef0", "kernel_params"]
if not all(
[
getattr(self, attr) == getattr(regressor, attr)
for attr in kernel_attrs
]
):
raise ValueError(
"Kernel parameter mismatch: the regressor has kernel parameters {%s}"
" and KernelPCovR was initialized with kernel parameters {%s}"
% (
", ".join(
[
"%s: %r" % (attr, getattr(regressor, attr))
for attr in kernel_attrs
]
),
", ".join(
[
"%s: %r" % (attr, getattr(self, attr))
for attr in kernel_attrs
]
),
)
)

# Check if regressor is fitted; if not, fit with precomputed K
# to avoid needing to compute the kernel a second time
self.regressor_ = check_krr_fit(regressor, K, X, Y)

if Yhat is None:
Yhat = K @ W
W = self.regressor_.dual_coef_.reshape(X.shape[0], -1)

# Use this instead of `self.regressor_.predict(K)`
# so that we can handle the case of the pre-fitted regressor
Yhat = K @ W

# When we have an unfitted regressor,
# we fit it with a precomputed K
# so we must subsequently "reset" it so that
# it will work on the particular X
# of the KPCovR call. The dual coefficients are kept.
# Can be bypassed if the regressor is pre-fitted.
try:
check_is_fitted(regressor)

except NotFittedError:
self.regressor_.set_params(**regressor.get_params())
self.regressor_.X_fit_ = self.X_fit_
self.regressor_._check_n_features(self.X_fit_, reset=True)

# Handle svd_solver
self._fit_svd_solver = self.svd_solver
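The helper `check_krr_fit` is imported from `..utils`, but its body is not shown in this diff. Based on the comments above, a plausible sketch follows; the name and every detail here are assumptions, not the actual skcosmo implementation:

```python
from copy import deepcopy

from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


def check_krr_fit_sketch(regressor, K, X, y):
    """Hypothetical stand-in for skcosmo.utils.check_krr_fit."""
    fitted = deepcopy(regressor)  # never mutate the user's estimator
    try:
        check_is_fitted(regressor)
    except NotFittedError:
        # Fit on the precomputed kernel so K is not computed twice
        fitted.set_params(kernel="precomputed")
        fitted.fit(K, y)
    return fitted
```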
@@ -408,7 +476,7 @@ def inverse_transform(self, T):

def score(self, X, Y):
r"""
Computes the loss values for KernelPCovR on the given predictor and
Computes the (negative) loss values for KernelPCovR on the given predictor and
response variables. The loss in :math:`\mathbf{K}`, as explained in
[Helfrecht2020]_ does not correspond to a traditional Gram loss
:math:`\mathbf{K} - \mathbf{TT}^T`. Indicating the kernel between set
@@ -424,15 +492,17 @@ def score(self, X, Y):
\mathbf{K}_{NN} \mathbf{T}_N (\mathbf{T}_N^T \mathbf{T}_N)^{-1}
\mathbf{T}_V^T\right]}{\operatorname{Tr}(\mathbf{K}_{VV})}

The negative loss is returned for easier use in sklearn pipelines, e.g., a grid search, where methods named 'score' are meant to be maximized.

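A sketch of the maximization convention in practice, assuming `X_train`/`Y_train`/`X_test`/`Y_test` splits are already defined:

```python
from sklearn.kernel_ridge import KernelRidge
from skcosmo.decomposition import KernelPCovR

models = [
    KernelPCovR(
        mixing=m,
        n_components=2,
        regressor=KernelRidge(kernel="rbf", gamma=1.0),
        kernel="rbf",
        gamma=1.0,
    ).fit(X_train, Y_train)
    for m in (0.1, 0.5, 0.9)
]

# score returns a negative loss, so the best model maximizes it
best = max(models, key=lambda model: model.score(X_test, Y_test))
```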
Arguments
---------
X: independent (predictor) variable
Y: dependent (response) variable

Returns
-------
Lk: KPCA loss, determined by the reconstruction of the kernel
Ly: KR loss
L: Negative sum of the KPCA and KRR losses, with the KPCA loss
determined by the reconstruction of the kernel

"""

@@ -455,10 +525,14 @@ def score(self, X, Y):
t_n = K_NN @ self.pkt_
t_v = K_VN @ self.pkt_

w = t_n @ np.linalg.pinv(t_n.T @ t_n, rcond=self.alpha) @ t_v.T
w = (
t_n
@ np.linalg.lstsq(t_n.T @ t_n, np.eye(t_n.shape[1]), rcond=self.tol)[0]
@ t_v.T
)
Lkpca = np.trace(K_VV - 2 * K_VN @ w + w.T @ K_VV @ w) / np.trace(K_VV)

return sum([Lkpca, Lkrr])
return -sum([Lkpca, Lkrr])

def _decompose_truncated(self, mat):

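Several commits in this PR (e.g. 3536cf5 and 9fa4988) swap regularized inverses for least-squares solves with a singular-value cutoff. A minimal NumPy illustration of the pattern used in `_fit` and `score` above, on a deliberately rank-deficient matrix:

```python
import numpy as np

A = np.array([[1.0, 2.0], [2.0, 4.0]])  # rank-deficient (rank 1)

# Least-squares pseudoinverse: singular values below rcond * s_max are
# dropped, instead of every singular value being shifted by alpha
A_pinv = np.linalg.lstsq(A, np.eye(2), rcond=1e-12)[0]
```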