Transform y into 2D tensor when y.ndim == 1 and likelihood == REGRESSION #240

Merged · 4 commits · Sep 14, 2024
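The gist of the change, as exercised by the new tests further down: for `likelihood='regression'`, `fit()` now checks that the target tensor has the same number of dimensions as the model output and raises a `ValueError` otherwise. A minimal sketch of the resulting behavior (assuming `DiagLaplace` is importable from the top-level `laplace` package, as the test suite does):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from laplace import DiagLaplace  # assumed top-level import path

# Single-output regression model: outputs have shape (batch, 1), i.e. 2 dims.
model = nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1))

X = torch.randn(10, 3)
y_2d = torch.randn(10, 1)   # matches the output's 2 dims -> accepted
y_flat = torch.randn(10)    # only 1 dim -> rejected for regression

la = DiagLaplace(model, likelihood="regression")
la.fit(DataLoader(TensorDataset(X, y_2d), batch_size=3))  # OK

la_bad = DiagLaplace(model, likelihood="regression")
try:
    la_bad.fit(DataLoader(TensorDataset(X, y_flat), batch_size=3))
except ValueError as e:
    print(e)  # "The model's output has 2 dims but the target has 1 dims."
```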
52 changes: 36 additions & 16 deletions laplace/baselaplace.py
@@ -846,6 +846,13 @@ def fit(
else:
X, y = data
X, y = X.to(self._device), y.to(self._device)

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

self.model.zero_grad()
loss_batch, H_batch = self._curv_closure(X, y, N=N)
self.loss += loss_batch
@@ -1761,12 +1768,19 @@ def fit(
if not self.enable_backprop:
self.mean = self.mean.detach()

X, _ = next(iter(train_loader))
X, y = next(iter(train_loader))
with torch.no_grad():
try:
out = self.model(X[:1].to(self._device))
except (TypeError, AttributeError):
out = self.model(X.to(self._device))

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

self.n_outputs = out.shape[-1]
setattr(self.model, "output_size", self.n_outputs)

@@ -1930,7 +1944,7 @@ class FunctionalLaplace(BaseLaplace):
See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
for more details.

Note that for `likelihood='classification'`, we approximate \( L_{NN} \\) with a diagonal matrix
Note that for `likelihood='classification'`, we approximate \\( L_{NN} \\) with a diagonal matrix
( \\( L_{NN} \\) is a block-diagonal matrix, where blocks represent Hessians of per-data-point log-likelihood w.r.t.
neural network output \\( f \\), See Appendix [A.2.1](https://arxiv.org/abs/2008.08400) for exact definition). We
resort to such an approximation because of the (possible) errors found in Laplace approximation for
@@ -2023,9 +2037,9 @@ def _check_prior_precision(prior_precision: float | torch.Tensor):

def _init_K_MM(self):
"""Allocates memory for the kernel matrix evaluated at the subset of the training
data points. If the subset is of size \(M\) and the problem has \(C\) outputs,
this is a list of C \((M,M\)) tensors for diagonal kernel and \((M x C, M x C)\)
otherwise.
data points. If the subset is of size \\(M\\) and the problem has \\(C\\) outputs,
this is a list of C \\((M,M\\)) tensors for diagonal kernel and
\\((M \\times C, M \\times C)\\) otherwise.
"""
if self.independent_outputs:
self.K_MM = [
@@ -2040,9 +2054,9 @@ def _init_K_MM(self):

def _init_Sigma_inv(self):
"""Allocates memory for the cholesky decomposition of
\[
K_{MM} + \Lambda_{MM}^{-1}.
\]
\\[
K_{MM} + \\Lambda_{MM}^{-1}.
\\]
See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
Equation 15 for more information.
"""
@@ -2115,13 +2129,13 @@ class for more details.

def _build_Sigma_inv(self):
"""Computes the cholesky decomposition of
\[
K_{MM} + \Lambda_{MM}^{-1}.
\]
\\[
K_{MM} + \\Lambda_{MM}^{-1}.
\\]
See [Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021)](https://arxiv.org/abs/2008.08400)
Equation 15 for more information.

As the diagonal approximation is performed with \Lambda_{MM} (which is stored in self.L),
As the diagonal approximation is performed with \\(\\Lambda_{MM}\\) (which is stored in self.L),
the code is greatly simplified.
"""
if self.independent_outputs:
@@ -2231,10 +2245,16 @@ def fit(

Js_batch, f_batch = self._jacobians(X, enable_backprop=False)

if self.likelihood == Likelihood.REGRESSION and y.ndim != out.ndim:
raise ValueError(
f"The model's output has {out.ndim} dims but "
f"the target has {y.ndim} dims."
)

with torch.no_grad():
loss_batch = self.backend.factor * self.backend.lossfunc(f_batch, y)

if self.likelihood == "regression":
if self.likelihood == Likelihood.REGRESSION:
b, C = f_batch.shape
lambdas_batch = torch.unsqueeze(torch.eye(C), 0).repeat(b, 1, 1)
else:
@@ -2552,11 +2572,11 @@ def log_det_ratio(self) -> torch.Tensor:
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with
(note that we always use diagonal approximation \\(D\\) of the Hessian of log likelihood w.r.t. \\(f\\)):

log determinant term := \\( \log | I + D^{1/2}K D^{1/2} | \\)
log determinant term := \\( \\log | I + D^{1/2}K D^{1/2} | \\)

For `regression`, we use ["standard" GP marginal likelihood](https://stats.stackexchange.com/questions/280105/log-marginal-likelihood-for-gaussian-process):

log determinant term := \\( \log | K + \\sigma_2 I | \\)
log determinant term := \\( \\log | K + \\sigma_2 I | \\)
"""
if self.likelihood == Likelihood.REGRESSION:
if self.independent_outputs:
@@ -2596,7 +2616,7 @@ def scatter(self, eps: float = 0.00001) -> torch.Tensor:
"""Compute scatter term in GP log marginal likelihood.

For `classification` we use eq. (3.44) from Chapter 3.5 from
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\hat{f} = f \\):
[GP book R&W 2006](http://www.gaussianprocess.org/gpml/chapters/) with \\(\\hat{f} = f \\):

scatter term := \\( f K^{-1} f^{T} \\)

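Why the dimension check matters for regression (an editorial note on motivation, not stated in the diff itself): with a model output of shape `(N, 1)` and a flat target of shape `(N,)`, PyTorch's `MSELoss` broadcasts the two tensors to `(N, N)` and only emits a `UserWarning`, so the loss, and any curvature fitted from it, would be silently wrong. A small sketch:

```python
import torch
from torch import nn

out = torch.randn(10, 1)   # model output, 2 dims
y = torch.randn(10)        # flat target, 1 dim

# Broadcasts (10, 1) against (10,) to (10, 10); PyTorch only warns.
loss_broadcast = nn.MSELoss()(out, y)
# Element-wise comparison, as intended for regression.
loss_correct = nn.MSELoss()(out, y.unsqueeze(-1))

print(loss_broadcast.item(), loss_correct.item())  # generally differ
```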
39 changes: 38 additions & 1 deletion tests/test_baselaplace.py
@@ -24,7 +24,7 @@
from tests.utils import ListDataset, dict_data_collator, jacobians_naive

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)

flavors = [FullLaplace, KronLaplace, DiagLaplace]
if find_spec("asdfghjkl") is not None:
@@ -43,6 +43,16 @@ def model():
return model


@pytest.fixture
def model_1d():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1))
setattr(model, "output_size", 1)
model_params = list(model.parameters())
setattr(model, "n_layers", len(model_params)) # number of parameter groups
setattr(model, "n_params", len(parameters_to_vector(model_params)))
return model


@pytest.fixture
def large_model():
model = wide_resnet50_2()
@@ -113,6 +123,22 @@ def reg_loader():
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn(10, 1)
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d_flat():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn((10,))
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def custom_loader_clf():
data = []
@@ -818,3 +844,14 @@ def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader

# Should not raise an error
lap.optimize_prior_precision(method="gridsearch", val_loader=dataloader, n_steps=10)


@pytest.mark.parametrize("laplace", flavors)
def test_parametric_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat, laplace):
lap = laplace(model_1d, likelihood="regression")
lap.fit(reg_loader_1d) # OK

lap2 = laplace(model_1d, likelihood="regression")

with pytest.raises(ValueError):
lap2.fit(reg_loader_1d_flat)
36 changes: 36 additions & 0 deletions tests/test_functional_laplace_unit.py
@@ -14,6 +14,22 @@ def reg_loader():
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn(10, 1)
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def reg_loader_1d_flat():
torch.manual_seed(9999)
X = torch.randn(10, 3)
y = torch.randn((10,))
return DataLoader(TensorDataset(X, y), batch_size=3)


@pytest.fixture
def model():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 2))
@@ -24,6 +40,16 @@ def model():
return model


@pytest.fixture
def model_1d():
model = torch.nn.Sequential(nn.Linear(3, 20), nn.Linear(20, 1))
setattr(model, "output_size", 1)
model_params = list(model.parameters())
setattr(model, "n_layers", len(model_params)) # number of parameter groups
setattr(model, "n_params", len(parameters_to_vector(model_params)))
return model


@pytest.fixture
def reg_Xy():
torch.manual_seed(711)
@@ -286,3 +312,13 @@ def mock_jacobians(self, x):
expected_block_diagonal_kernel,
block_diag_kernel.to(expected_block_diagonal_kernel.dtype),
)


def test_functional_fit_y_shape(model_1d, reg_loader_1d, reg_loader_1d_flat):
la = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False)
la.fit(reg_loader_1d)

la2 = FunctionalLaplace(model_1d, "regression", 10, independent_outputs=False)

with pytest.raises(ValueError):
la2.fit(reg_loader_1d_flat)
2 changes: 1 addition & 1 deletion tests/test_laplace.py
@@ -8,7 +8,7 @@
from laplace.lllaplace import DiagLLLaplace, FullLLLaplace, KronLLLaplace

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
flavors = [
FullLaplace,
KronLaplace,
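The test modules also swap `torch.set_default_tensor_type(torch.DoubleTensor)` for `torch.set_default_dtype(torch.double)`. To my understanding, `set_default_tensor_type` is deprecated in recent PyTorch releases, and only the default floating-point dtype is needed here; a minimal sketch of the replacement:

```python
import torch

# Deprecated in recent PyTorch (kept only as a comment):
# torch.set_default_tensor_type(torch.DoubleTensor)

# Replacement: new floating-point tensors default to float64.
torch.set_default_dtype(torch.double)

x = torch.randn(3)
print(x.dtype)  # torch.float64
```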
2 changes: 1 addition & 1 deletion tests/test_matrix.py
@@ -9,7 +9,7 @@
from laplace.utils import kron as kron_prod
from tests.utils import get_diag_psd_matrix, get_psd_matrix, jacobians_naive

torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)


@pytest.fixture
2 changes: 1 addition & 1 deletion tests/test_serialization.py
@@ -26,7 +26,7 @@
)

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)

lrlaplace_param = pytest.param(
LowRankLaplace, marks=pytest.mark.xfail(reason="Unimplemented in the new ASDL")
2 changes: 1 addition & 1 deletion tests/test_subnetlaplace.py
@@ -21,7 +21,7 @@
)

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
score_based_subnet_masks = [
RandomSubnetMask,
LargestMagnitudeSubnetMask,
2 changes: 1 addition & 1 deletion tests/test_subset_params.py
@@ -12,7 +12,7 @@
from laplace.curvature.curvlinops import CurvlinopsEF, CurvlinopsGGN, CurvlinopsHessian

torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.double)
flavors = [KronLaplace, DiagLaplace, FullLaplace]
valid_backends = [CurvlinopsGGN, CurvlinopsEF, AsdlGGN, AsdlEF]
