Enable BayesOpt #120

Merged · 15 commits merged on May 26, 2023

Changes from 1 commit
Support joint predictive distribution, ala GP
wiseodd committed Mar 9, 2023
commit 6066b998d179632fffa70b014966af0384ebc2ec
17 changes: 15 additions & 2 deletions examples/regression_example.py
@@ -47,9 +47,22 @@ def get_model():
       f'prior precision={la.prior_precision.item():.2f}')
 
 x = X_test.flatten().cpu().numpy()
-f_mu, f_var = la(X_test)
+
+# Two options:
+# 1.) Marginal predictive distribution N(f_map(x_i), var(x_i))
+#     The mean is (m, k), the var is (m, k, k)
+f_mu, f_var = la(X_test)
+
+# 2.) Joint pred. dist. N((f_map(x_1), ..., f_map(x_m)), Cov(f(x_1), ..., f(x_m)))
+#     The mean is (m*k,) where k is the output dim. The cov is (m*k, m*k)
+f_mu_joint, f_cov = la(X_test, joint=True)
+
+# Both should be True
+print(torch.allclose(f_mu.flatten(), f_mu_joint))
+print(torch.allclose(f_var.flatten(), f_cov.diag()))
 
 f_mu = f_mu.squeeze().detach().cpu().numpy()
-f_sigma = f_var.squeeze().sqrt().cpu().numpy()
+f_sigma = f_var.squeeze().detach().sqrt().cpu().numpy()
 pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)
 
 plot_regression(X_train, y_train, x, f_mu, pred_std,
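Since the point of this PR is to enable Bayesian optimization, here is a minimal sketch (not part of the diff) of how the joint predictive returned by la(X, joint=True) could be consumed downstream, e.g. for Thompson sampling over a batch of candidates. It assumes the fitted Laplace object la from the example above; X_cand and the acquisition strategy are hypothetical illustrations.

import torch

# Hypothetical candidate inputs for one acquisition step (any (m, d) tensor works).
X_cand = torch.linspace(-2, 2, 20).unsqueeze(-1)

# Joint predictive over all candidates: mean (m*k,), covariance (m*k, m*k).
f_mu_joint, f_cov = la(X_cand, joint=True)
f_mu_joint, f_cov = f_mu_joint.detach(), f_cov.detach()

# A small jitter keeps the covariance numerically positive definite.
jitter = 1e-6 * torch.eye(f_cov.shape[0])
pred = torch.distributions.MultivariateNormal(f_mu_joint, f_cov + jitter)

# Thompson sampling: draw one correlated function sample and query its argmax next.
f_sample = pred.sample()
x_next = X_cand[f_sample.argmax()]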
50 changes: 45 additions & 5 deletions laplace/baselaplace.py
@@ -496,8 +496,8 @@ def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None):
 
         return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)
 
-    def __call__(self, x, pred_type='glm', link_approx='probit', n_samples=100,
-                 diagonal_output=False, generator=None):
+    def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',
+                 n_samples=100, diagonal_output=False, generator=None):
         """Compute the posterior predictive on input data `x`.
 
         Parameters
@@ -514,6 +514,12 @@ def __call__(self, x, pred_type='glm', link_approx='probit', n_samples=100,
             how to approximate the classification link function for the `'glm'`.
             For `pred_type='nn'`, only 'mc' is possible.
 
+        joint : bool
+            Whether to output a joint predictive distribution in regression with
+            `pred_type='glm'`. If set to `True`, the predictive distribution
+            has the same form as a GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]).
+            If `False`, only the marginal predictive distribution is returned.
+
         n_samples : int
             number of samples for `link_approx='mc'`.
 
@@ -531,6 +537,8 @@ def __call__(self, x, pred_type='glm', link_approx='probit', n_samples=100,
             a distribution over classes (similar to a Softmax).
             For `likelihood='regression'`, a tuple of torch.Tensor is returned
             with the mean and the predictive variance.
+            For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
+            is returned with the mean and the predictive covariance.
         """
         if pred_type not in ['glm', 'nn']:
             raise ValueError('Only glm and nn supported as prediction types.')
@@ -546,7 +554,7 @@ def __call__(self, x, pred_type='glm', link_approx='probit', n_samples=100,
                 raise ValueError('Invalid random generator (check type and device).')
 
         if pred_type == 'glm':
-            f_mu, f_var = self._glm_predictive_distribution(x)
+            f_mu, f_var = self._glm_predictive_distribution(x, joint=joint)
             # regression
             if self.likelihood == 'regression':
                 return f_mu, f_var
@@ -628,9 +636,15 @@ def predictive_samples(self, x, pred_type='glm', n_samples=100,
             return self._nn_predictive_samples(x, n_samples)
 
     @torch.enable_grad()
-    def _glm_predictive_distribution(self, X):
+    def _glm_predictive_distribution(self, X, joint=False):
         Js, f_mu = self.backend.jacobians(X)
-        f_var = self.functional_variance(Js)
+
+        if joint:
+            f_mu = f_mu.flatten()  # (batch*out)
Owner:
Could this be risky when we have multi-output regression, or do we not support this anyway?

Author (Collaborator):
It follows the standard practice in NTK papers, where f_mu has shape (nk,) and f_cov has shape (nk, nk), no? Do you have another idea?
+            f_var = self.functional_covariance(Js)  # (batch*out, batch*out)
+        else:
+            f_var = self.functional_variance(Js)
 
         return f_mu, f_var
 
     def _nn_predictive_samples(self, X, n_samples=100):
@@ -666,6 +680,27 @@ def functional_variance(self, Jacs):
         """
         raise NotImplementedError
 
+    def functional_covariance(self, Jacs):
+        """Compute functional covariance for the `'glm'` predictive:
+        `f_cov = Jacs @ P.inv() @ Jacs.T`, which is a batch*output x batch*output
+        predictive covariance matrix.
+
+        This emulates the GP posterior covariance N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]).
+        Useful for joint predictions, such as in batched Bayesian optimization.
+
+        Parameters
+        ----------
+        Jacs : torch.Tensor
+            Jacobians of model output wrt parameters
+            `(batch, outputs, parameters)`
+
+        Returns
+        -------
+        f_cov : torch.Tensor
+            output covariance `(batch*outputs, batch*outputs)`
+        """
+        raise NotImplementedError
+
     def sample(self, n_samples=100):
         """Sample from the Laplace posterior approximation, i.e.,
         \\( \\theta \\sim \\mathcal{N}(\\theta_{MAP}, P^{-1})\\).
@@ -778,6 +813,11 @@ def square_norm(self, value):
     def functional_variance(self, Js):
         return torch.einsum('ncp,pq,nkq->nck', Js, self.posterior_covariance, Js)
 
+    def functional_covariance(self, Js):
+        n_batch, n_outs, n_params = Js.shape
+        Js = Js.reshape(n_batch * n_outs, n_params)
+        return torch.einsum('np,pq,mq->nm', Js, self.posterior_covariance, Js)
+
     def sample(self, n_samples=100):
         dist = MultivariateNormal(loc=self.mean, scale_tril=self.posterior_scale)
         return dist.sample((n_samples,))
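A small self-contained sanity check (not part of the diff) that the two einsum contractions above are consistent: the k x k diagonal blocks of the joint covariance from functional_covariance should equal the per-point covariances from functional_variance. The random Js and posterior_covariance below are stand-ins for the real Jacobians and posterior covariance.

import torch

n, k, p = 4, 2, 10                             # batch, outputs, parameters
Js = torch.randn(n, k, p)                      # stand-in Jacobians
A = torch.randn(p, p)
posterior_covariance = A @ A.T + torch.eye(p)  # stand-in SPD posterior covariance

# Marginal: per-point (k, k) covariances, shape (n, k, k).
f_var = torch.einsum('ncp,pq,nkq->nck', Js, posterior_covariance, Js)

# Joint: full (n*k, n*k) covariance over all points and outputs.
Js_flat = Js.reshape(n * k, p)
f_cov = torch.einsum('np,pq,mq->nm', Js_flat, posterior_covariance, Js_flat)

# The i-th (k x k) diagonal block of the joint covariance is the i-th marginal covariance.
for i in range(n):
    block = f_cov[i * k:(i + 1) * k, i * k:(i + 1) * k]
    assert torch.allclose(block, f_var[i], atol=1e-5)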