Skip to content

Commit

Permalink
Add output-dim check for low rank Laplace
Browse files Browse the repository at this point in the history
  • Loading branch information
wiseodd committed Sep 12, 2024
1 parent 0331f05 commit 99c1ebd
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions laplace/baselaplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,8 @@ def fit(

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output is of shape {tuple(out.shape)} but "
f"the target has shape {tuple(y.shape)}."
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

self.model.zero_grad()
Expand Down Expand Up @@ -1768,12 +1768,19 @@ def fit(
if not self.enable_backprop:
self.mean = self.mean.detach()

X, _ = next(iter(train_loader))
X, y = next(iter(train_loader))
with torch.no_grad():
try:
out = self.model(X[:1].to(self._device))
except (TypeError, AttributeError):
out = self.model(X.to(self._device))

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

self.n_outputs = out.shape[-1]
setattr(self.model, "output_size", self.n_outputs)

Expand Down Expand Up @@ -2240,8 +2247,8 @@ def fit(

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output is of shape {tuple(out.shape)} but "
f"the target has shape {tuple(y.shape)}."
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

with torch.no_grad():
Expand Down

0 comments on commit 99c1ebd

Please sign in to comment.