ENH GPR.log_marginal_likelihood() returns the current log-likelihood if no theta vector is provided
Jan Hendrik Metzen authored and glouppe committed Oct 19, 2015
1 parent c3a41a3 commit 3957fc3
Showing 4 changed files with 88 additions and 22 deletions.
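Before the file-by-file diff, here is a minimal usage sketch of the new behavior; the toy data and the choice of an RBF kernel are illustrative assumptions, not part of the commit:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

# Toy 1-D regression data (made up for illustration)
X = np.atleast_2d([1., 3., 5., 6., 7.]).T
y = np.sin(X).ravel()

gpr = GaussianProcessRegressor(kernel=RBF(length_scale=1.0)).fit(X, y)

# With this commit, omitting theta returns the log-marginal likelihood
# that fit() precomputed for the optimized kernel ...
lml_stored = gpr.log_marginal_likelihood()
# ... which matches an explicit evaluation at the optimized theta.
lml_explicit = gpr.log_marginal_likelihood(gpr.kernel_.theta)
assert np.isclose(lml_stored, lml_explicit)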
52 changes: 41 additions & 11 deletions sklearn/gaussian_process/gpc.py
@@ -135,6 +135,9 @@ def optimizer(obj_func, initial_theta, bounds):
Square root of W, the Hessian of log-likelihood of the latent function
values for the observed labels. Since W is diagonal, only the diagonal
of sqrt(W) is stored.
log_marginal_likelihood_value_: float
The log-marginal-likelihood of self.kernel_.theta
"""
def __init__(self, kernel=None, optimizer="fmin_l_bfgs_b",
n_restarts_optimizer=0, max_iter_predict=100,
@@ -185,9 +188,7 @@ def fit(self, X, y):
raise ValueError("{0:s} requires 2 classes.".format(
self.__class__.__name__))

if self.kernel_.n_dims == 0: # no tunable hyperparameters
pass
elif self.optimizer is not None:
if self.optimizer is not None and self.kernel_.n_dims > 0:
# Choose hyperparameters based on maximizing the log-marginal
# likelihood (potentially starting from several initial values)
def obj_func(theta, eval_gradient=True):
@@ -221,6 +222,10 @@ def obj_func(theta, eval_gradient=True):
# likelihood
lml_values = map(itemgetter(1), optima)
self.kernel_.theta = optima[np.argmin(lml_values)][0]
self.log_marginal_likelihood_value_ = -np.min(lml_values)
else:
self.log_marginal_likelihood_value_ = \
self.log_marginal_likelihood(self.kernel_.theta)

# Precompute quantities required for predictions which are independent
# of actual query points
@@ -292,19 +297,20 @@ def predict_proba(self, X):

return np.vstack((1 - pi_star, pi_star)).T

def log_marginal_likelihood(self, theta, eval_gradient=False):
def log_marginal_likelihood(self, theta=None, eval_gradient=False):
"""Returns log-marginal likelihood of theta for training data.
Parameters
----------
theta : array-like, shape = (n_kernel_params,)
theta : array-like, shape = (n_kernel_params,) or None
Kernel hyperparameters for which the log-marginal likelihood is
evaluated
evaluated. If None, the precomputed log_marginal_likelihood
of self.kernel_.theta is returned.
eval_gradient : bool, default: False
If True, the gradient of the log-marginal likelihood with respect
to the kernel hyperparameters at position theta is returned
additionally.
additionally. If True, theta must not be None.
Returns
-------
@@ -316,6 +322,12 @@ def log_marginal_likelihood(self, theta, eval_gradient=False):
hyperparameters at position theta.
Only returned when eval_gradient is True.
"""
if theta is None:
if eval_gradient:
raise ValueError(
"Gradient can only be evaluated for theta!=None")
return self.log_marginal_likelihood_value_

kernel = self.kernel_.clone_with_theta(theta)

if eval_gradient:
@@ -526,6 +538,9 @@ def optimizer(obj_func, initial_theta, bounds):
classification, a CompoundKernel is returned which consists of the
different kernels used in the one-versus-rest classifiers.
log_marginal_likelihood_value_: float
The log-marginal-likelihood of self.kernel_.theta
classes_ : array-like, shape = (n_classes,)
Unique class labels.
@@ -589,6 +604,14 @@ def fit(self, X, y):

self.base_estimator_.fit(X, y)

if self.n_classes_ > 2:
self.log_marginal_likelihood_value_ = np.mean(
[estimator.log_marginal_likelihood()
for estimator in self.base_estimator_.estimators_])
else:
self.log_marginal_likelihood_value_ = \
self.base_estimator_.log_marginal_likelihood()

return self

def predict(self, X):
@@ -638,26 +661,27 @@ def kernel_(self):
[estimator.kernel_
for estimator in self.base_estimator_.estimators_])

def log_marginal_likelihood(self, theta, eval_gradient=False):
def log_marginal_likelihood(self, theta=None, eval_gradient=False):
"""Returns log-marginal likelihood of theta for training data.
In the case of multi-class classification, the mean log-marginal
likelihood of the one-versus-rest classifiers is returned.
Parameters
----------
theta : array-like, shape = (n_kernel_params,)
theta : array-like, shape = (n_kernel_params,) or None
Kernel hyperparameters for which the log-marginal likelihood is
evaluated. In the case of multi-class classification, theta may
be the hyperparameters of the compound kernel or of an individual
kernel. In the latter case, all individual kernels get assigned the
same theta values.
same theta values. If None, the precomputed log_marginal_likelihood
of self.kernel_.theta is returned.
eval_gradient : bool, default: False
If True, the gradient of the log-marginal likelihood with respect
to the kernel hyperparameters at position theta is returned
additionally. Note that gradient computation is not supported
for non-binary classification.
for non-binary classification. If True, theta must not be None.
Returns
-------
@@ -671,6 +695,12 @@ def log_marginal_likelihood(self, theta, eval_gradient=False):
"""
check_is_fitted(self, ["classes_", "n_classes_"])

if theta is None:
if eval_gradient:
raise ValueError(
"Gradient can only be evaluated for theta!=None")
return self.log_marginal_likelihood_value_

theta = np.asarray(theta)
if self.n_classes_ == 2:
return self.base_estimator_.log_marginal_likelihood(
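The fit() changes above mean that even when the optimizer is skipped (optimizer=None, or a kernel with no tunable hyperparameters), log_marginal_likelihood_value_ is still populated via the new else-branch. A small sketch of that case, with a toy dataset assumed for illustration:

import numpy as np
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

# Toy binary labels (made up for illustration)
X = np.atleast_2d([1., 3., 5., 6., 7.]).T
y = (X.ravel() > 4).astype(int)

# A kernel whose only hyperparameter is fixed has nothing to optimize,
# so fit() takes the else-branch and stores the LML of the initial theta.
fixed = RBF(length_scale=1.0, length_scale_bounds="fixed")
gpc = GaussianProcessClassifier(kernel=fixed).fit(X, y)

print(gpc.log_marginal_likelihood_value_)
print(gpc.log_marginal_likelihood())  # theta omitted: returns the stored value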
26 changes: 19 additions & 7 deletions sklearn/gaussian_process/gpr.py
@@ -119,6 +119,9 @@ def optimizer(obj_func, initial_theta, bounds):
alpha_: array-like, shape = (n_samples,)
Dual coefficients of training data points in kernel space
log_marginal_likelihood_value_: float
The log-marginal-likelihood of self.kernel_.theta
"""
def __init__(self, kernel=None, alpha=1e-10,
optimizer="fmin_l_bfgs_b", n_restarts_optimizer=0,
@@ -176,9 +179,7 @@ def fit(self, X, y):
self.X_train_ = np.copy(X) if self.copy_X_train else X
self.y_train_ = np.copy(y) if self.copy_X_train else y

if self.kernel_.n_dims == 0: # no tunable hyperparameters
pass
elif self.optimizer is not None:
if self.optimizer is not None and self.kernel_.n_dims > 0:
# Choose hyperparameters based on maximizing the log-marginal
# likelihood (potentially starting from several initial values)
def obj_func(theta, eval_gradient=True):
@@ -212,6 +213,10 @@ def obj_func(theta, eval_gradient=True):
# likelihood
lml_values = map(itemgetter(1), optima)
self.kernel_.theta = optima[np.argmin(lml_values)][0]
self.log_marginal_likelihood_value_ = -np.min(lml_values)
else:
self.log_marginal_likelihood_value_ = \
self.log_marginal_likelihood(self.kernel_.theta)

# Precompute quantities required for predictions which are independent
# of actual query points
@@ -334,19 +339,20 @@ def sample_y(self, X, n_samples=1, random_state=0):
y_samples = np.hstack(y_samples)
return y_samples

def log_marginal_likelihood(self, theta, eval_gradient=False):
def log_marginal_likelihood(self, theta=None, eval_gradient=False):
"""Returns log-marginal likelihood of theta for training data.
Parameters
----------
theta : array-like, shape = (n_kernel_params,)
theta : array-like, shape = (n_kernel_params,) or None
Kernel hyperparameters for which the log-marginal likelihood is
evaluated
evaluated. If None, the precomputed log_marginal_likelihood
of self.kernel_.theta is returned.
eval_gradient : bool, default: False
If True, the gradient of the log-marginal likelihood with respect
to the kernel hyperparameters at position theta is returned
additionally.
additionally. If True, theta must not be None.
Returns
-------
@@ -358,6 +364,12 @@ def log_marginal_likelihood(self, theta, eval_gradient=False):
hyperparameters at position theta.
Only returned when eval_gradient is True.
"""
if theta is None:
if eval_gradient:
raise ValueError(
"Gradient can only be evaluated for theta!=None")
return self.log_marginal_likelihood_value_

kernel = self.kernel_.clone_with_theta(theta)

if eval_gradient:
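As in the classifier, passing eval_gradient=True without an explicit theta is rejected by the new guard. A short sketch of both call patterns, again with made-up data:

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

X = np.atleast_2d([1., 3., 5., 6., 7.]).T          # illustrative data
y = np.sin(X).ravel()
gpr = GaussianProcessRegressor(kernel=RBF(length_scale=1.0)).fit(X, y)

try:
    # The gradient is only defined for an explicit theta, so this raises.
    gpr.log_marginal_likelihood(theta=None, eval_gradient=True)
except ValueError as err:
    print(err)   # "Gradient can only be evaluated for theta!=None"

# With an explicit theta, the value and its gradient are both returned.
lml, lml_grad = gpr.log_marginal_likelihood(gpr.kernel_.theta,
                                            eval_gradient=True)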
16 changes: 14 additions & 2 deletions sklearn/gaussian_process/tests/test_gpc.py
@@ -10,7 +10,7 @@
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C

from sklearn.utils.testing import (assert_true, assert_greater,
from sklearn.utils.testing import (assert_true, assert_greater, assert_equal,
assert_almost_equal, assert_array_equal)


@@ -26,7 +26,8 @@ def f(x):
y_mc[fX > 0.35] = 2


kernels = [RBF(length_scale=0.1),
fixed_kernel = RBF(length_scale=1.0, length_scale_bounds="fixed")
kernels = [RBF(length_scale=0.1), fixed_kernel,
RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e3)),
C(1.0, (1e-2, 1e2))
* RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e3))]
@@ -44,14 +45,24 @@ def test_predict_consistent():
def test_lml_improving():
""" Test that hyperparameter-tuning improves log-marginal likelihood. """
for kernel in kernels:
if kernel == fixed_kernel: continue
gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y)
assert_greater(gpc.log_marginal_likelihood(gpc.kernel_.theta),
gpc.log_marginal_likelihood(kernel.theta))


def test_lml_precomputed():
""" Test that lml of optimized kernel is stored correctly. """
for kernel in kernels:
gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y)
assert_equal(gpc.log_marginal_likelihood(gpc.kernel_.theta),
gpc.log_marginal_likelihood())


def test_converged_to_local_maximum():
""" Test that we are in local maximum after hyperparameter-optimization."""
for kernel in kernels:
if kernel == fixed_kernel: continue
gpc = GaussianProcessClassifier(kernel=kernel).fit(X, y)

lml, lml_gradient = \
@@ -117,6 +128,7 @@ def optimizer(obj_func, initial_theta, bounds):
return theta_opt, func_min

for kernel in kernels:
if kernel == fixed_kernel: continue
gpc = GaussianProcessClassifier(kernel=kernel, optimizer=optimizer)
gpc.fit(X, y_mc)
# Checks that optimizer improved marginal likelihood
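For the multi-class wrapper, the stored value is the mean of the per-class one-versus-rest estimators' log-marginal likelihoods (see the fit() hunk in gpc.py above). A sketch of that relationship, with made-up three-class data mirroring the y_mc setup in these tests:

import numpy as np
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF

# Made-up three-class data (illustrative only)
X = np.atleast_2d([1., 2., 3., 5., 6., 7., 8., 9.]).T
y_mc = np.array([0, 0, 0, 1, 1, 2, 2, 2])

gpc = GaussianProcessClassifier(kernel=RBF(length_scale=1.0)).fit(X, y_mc)

# One binary (one-vs-rest) estimator per class, each with its own stored LML
per_class_lml = [est.log_marginal_likelihood()
                 for est in gpc.base_estimator_.estimators_]
assert np.isclose(gpc.log_marginal_likelihood_value_, np.mean(per_class_lml))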
16 changes: 14 additions & 2 deletions sklearn/gaussian_process/tests/test_gpr.py
@@ -22,8 +22,8 @@ def f(x):
X2 = np.atleast_2d([2., 4., 5.5, 6.5, 7.5]).T
y = f(X).ravel()


kernels = [RBF(length_scale=1.0),
fixed_kernel = RBF(length_scale=1.0, length_scale_bounds="fixed")
kernels = [RBF(length_scale=1.0), fixed_kernel,
RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e3)),
C(1.0, (1e-2, 1e2))
* RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e3)),
@@ -48,14 +48,24 @@ def test_gpr_interpolation():
def test_lml_improving():
""" Test that hyperparameter-tuning improves log-marginal likelihood. """
for kernel in kernels:
if kernel == fixed_kernel: continue
gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)
assert_greater(gpr.log_marginal_likelihood(gpr.kernel_.theta),
gpr.log_marginal_likelihood(kernel.theta))


def test_lml_precomputed():
""" Test that lml of optimized kernel is stored correctly. """
for kernel in kernels:
gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)
assert_equal(gpr.log_marginal_likelihood(gpr.kernel_.theta),
gpr.log_marginal_likelihood())


def test_converged_to_local_maximum():
""" Test that we are in local maximum after hyperparameter-optimization."""
for kernel in kernels:
if kernel == fixed_kernel: continue
gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)

lml, lml_gradient = \
@@ -69,6 +79,7 @@ def test_converged_to_local_maximum():
def test_solution_inside_bounds():
""" Test that hyperparameter-optimization remains in bounds"""
for kernel in kernels:
if kernel == fixed_kernel: continue
gpr = GaussianProcessRegressor(kernel=kernel).fit(X, y)

bounds = gpr.kernel_.bounds
@@ -270,6 +281,7 @@ def optimizer(obj_func, initial_theta, bounds):
return theta_opt, func_min

for kernel in kernels:
if kernel == fixed_kernel: continue
gpr = GaussianProcessRegressor(kernel=kernel, optimizer=optimizer)
gpr.fit(X, y)
# Checks that optimizer improved marginal likelihood
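The fixed_kernel added to both test modules is skipped in the optimization-related tests because a kernel whose bounds are "fixed" exposes no free hyperparameters. A quick sketch of why there is nothing to tune in that case:

from sklearn.gaussian_process.kernels import RBF

tunable = RBF(length_scale=1.0, length_scale_bounds=(1e-3, 1e3))
fixed = RBF(length_scale=1.0, length_scale_bounds="fixed")

print(tunable.theta)   # one free (log-transformed) hyperparameter
print(fixed.theta)     # empty array: nothing for the optimizer to vary
print(fixed.n_dims)    # 0, which is why fit() skips the hyperparameter search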
