Support for multidim outputs #266

Open · wants to merge 7 commits into base: main
Changes from 1 commit
Update ASDL dep
wiseodd committed Dec 6, 2024
commit 28cccaa8ec6c7b8b694081ebfc7f49b2a2ce03fa
11 changes: 6 additions & 5 deletions examples/lm_example.py
@@ -61,12 +61,13 @@ def forward(self, x):

data = next(iter(dl)) # data[INPUT_KEY].shape = (B, L * D)
pred_map = model(data) # (B, D)
pred_la_mean, pred_la_var = la(data, pred_type=PredType.GLM)

# Detach the grad if you don't need to backprop
pred_la_mean, pred_la_var = pred_la_mean.detach(), pred_la_var.detach()
pred_la_mean, pred_la_var = la(data, pred_type=PredType.GLM)
# torch.Size([B, L, K]) torch.Size([B, L, K, K])
print(pred_la_mean.shape, pred_la_var.shape)

# torch.Size([4, 6, 1]) torch.Size([4, 6, 1, 1])
pred_la_mean, pred_la_var = la(data, pred_type=PredType.GLM, diagonal_output=True)
# torch.Size([B, L, K]) torch.Size([B, L, K])
print(pred_la_mean.shape, pred_la_var.shape)


@@ -88,5 +89,5 @@ def forward(self, x):
data, pred_type=PredType.NN, link_approx=LinkApprox.MC, n_samples=10
)

# torch.Size([4, 6, 1]) torch.Size([4, 6, 1])
# torch.Size([B, L, K]) torch.Size([B, L, K])
print(pred_la_mean.shape, pred_la_var.shape)
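
For reference, the shape change above can be reproduced with plain PyTorch: the full GLM predictive carries a `(B, L, K, K)` covariance per position, and `diagonal_output=True` corresponds to keeping only the per-output variances of shape `(B, L, K)`. A minimal sketch (the tensors below are stand-ins, not part of this diff; `B=4, L=6, K=1` as in the example's comments):

```python
import torch

B, L, K = 4, 6, 1  # batch size, sequence length, output dim (from the example's comments)

# Stand-ins for the GLM predictive outputs: mean (B, L, K), covariance (B, L, K, K)
mean = torch.zeros(B, L, K)
cov = torch.eye(K).expand(B, L, K, K)

# diagonal_output=True corresponds to keeping only the diagonal of each (K, K)
# covariance block, i.e. the per-output variances of shape (B, L, K)
var = cov.diagonal(dim1=-2, dim2=-1)

print(mean.shape, cov.shape, var.shape)
# torch.Size([4, 6, 1]) torch.Size([4, 6, 1, 1]) torch.Size([4, 6, 1])
```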
83 changes: 73 additions & 10 deletions laplace/baselaplace.py
@@ -158,7 +158,7 @@ def __init__(
)

# log likelihood = g(loss)
self.loss: float = 0.0
self.loss: float | torch.Tensor = 0.0
self.n_outputs: int = 0
self.n_data: int = 0

@@ -234,10 +234,69 @@ def log_likelihood(self) -> torch.Tensor:
def __call__(
self,
x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
pred_type: PredType | str,
link_approx: LinkApprox | str,
n_samples: int,
pred_type: PredType | str = PredType.GLM,
joint: bool = False,
link_approx: LinkApprox | str = LinkApprox.PROBIT,
n_samples: int = 1,
diagonal_output: bool = False,
generator: torch.Generator | None = None,
fitting: bool = False,
**model_kwargs: dict[str, Any],
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Compute the posterior predictive on input data `x`.

Parameters
----------
x : torch.Tensor or MutableMapping
`(batch_size, input_shape)` if a tensor. If a MutableMapping, it must
contain that tensor.

pred_type : {'glm', 'nn'}, default='glm'
type of posterior predictive: the linearized GLM predictive or the
neural-network sampling predictive. The GLM predictive is consistent with
the curvature approximations used here. When the Laplace approximation
covers only a subset of parameters (i.e. some gradients are disabled),
only the `nn` predictive is supported.

link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
how to approximate the classification link function for the `'glm'`
predictive. For `pred_type='nn'`, only `'mc'` is possible.

joint : bool
Whether to output a joint predictive distribution in regression with
`pred_type='glm'`. If `True`, the predictive distribution has the same
form as a GP posterior, i.e. N([f(x1), ..., f(xm)], Cov[f(x1), ..., f(xm)]).
If `False`, only the marginal predictive distribution is returned.
Only available for regression with the GLM predictive.

n_samples : int
number of samples for `link_approx='mc'`.

diagonal_output : bool
whether to use a diagonalized posterior predictive on the outputs.
Only works for `pred_type='glm'` when `joint=False` in regression.
In the case of last-layer Laplace with a diagonal or Kron Hessian,
setting this to `True` makes computation much(!) faster for a large
number of outputs.

generator : torch.Generator, optional
random number generator to control the samples (if sampling used).

fitting : bool, default=False
whether or not this predictive call is done during fitting. Only useful for
reward modeling: the likelihood is set to `"regression"` when `False` and
`"classification"` when `True`.

Returns
-------
predictive: torch.Tensor or tuple[torch.Tensor]
For `likelihood='classification'`, a torch.Tensor is returned with
a distribution over classes (similar to a Softmax).
For `likelihood='regression'`, a tuple of torch.Tensor is returned
with the mean and the predictive variance.
For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
is returned with the mean and the predictive covariance.
"""
raise NotImplementedError

def predictive(
@@ -247,7 +306,9 @@ def predictive(
link_approx: LinkApprox | str,
n_samples: int,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self(x, pred_type, link_approx, n_samples)
return self(
x, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples
)

def _check_jacobians(self, Js: torch.Tensor) -> None:
if not isinstance(Js, torch.Tensor):
@@ -492,7 +553,9 @@ def optimize_prior_precision(

def _gridsearch(
self,
loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float],
loss: torchmetrics.Metric
| Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
| Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
interval: torch.Tensor,
val_loader: DataLoader,
pred_type: PredType | str,
@@ -989,7 +1052,7 @@ def __call__(
pred_type: PredType | str = PredType.GLM,
joint: bool = False,
link_approx: LinkApprox | str = LinkApprox.PROBIT,
n_samples: int = 100,
n_samples: int = 1,
diagonal_output: bool = False,
generator: torch.Generator | None = None,
fitting: bool = False,
@@ -1197,7 +1260,7 @@ def _nn_predictive_classification(
n_samples: int = 100,
**model_kwargs: dict[str, Any],
) -> torch.Tensor:
py = 0.0
py = torch.tensor(0.0)
for sample in self.sample(n_samples):
vector_to_parameters(sample, self.params)
logits = self.model(
@@ -2307,11 +2370,11 @@ def _glm_predictive_distribution(self, X: torch.Tensor, joint: bool = False):

def __call__(
self,
x: torch.Tensor | MutableMapping,
x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
pred_type: PredType | str = PredType.GP,
joint: bool = False,
link_approx: LinkApprox | str = LinkApprox.PROBIT,
n_samples: int = 100,
n_samples: int = 1,
diagonal_output: bool = False,
generator: torch.Generator | None = None,
fitting: bool = False,
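
To make the documented defaults concrete, here is a hedged end-to-end sketch against the public `Laplace` API. The toy data, model, and hyperparameters are illustrative and not part of this PR, and the shapes shown are for the standard 2D output case `(batch, K)`; the multidim `(B, L, K)` case is exercised by the example script above.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace  # public entry point of laplace-torch

torch.manual_seed(0)

# Toy regression problem with K = 3 outputs (names and sizes are placeholders)
X = torch.randn(64, 10)
y = X @ torch.randn(10, 3) + 0.1 * torch.randn(64, 3)
loader = DataLoader(TensorDataset(X, y), batch_size=16)
model = nn.Sequential(nn.Linear(10, 32), nn.Tanh(), nn.Linear(32, 3))

la = Laplace(model, "regression", subset_of_weights="all", hessian_structure="diag")
la.fit(loader)

x = X[:4]
# GLM predictive (the default pred_type): mean and full per-input output covariance
mean, cov = la(x, pred_type="glm")
# Diagonalized outputs: per-output variances instead of (K, K) covariance blocks
mean, var = la(x, pred_type="glm", diagonal_output=True)
print(mean.shape, cov.shape, var.shape)  # expected: (4, 3), (4, 3, 3), (4, 3)
```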
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
license = { text = "MIT" }
requires-python = ">=3.9"
dependencies = [
"asdfghjkl == 0.1a4",
"asdfghjkl == 0.1a5",
"backpack-for-pytorch",
"curvlinops-for-pytorch >= 2.0",
"numpy",
26 changes: 4 additions & 22 deletions uv.lock
