Make FunctionalLaplace respect dtype
wiseodd committed Sep 20, 2024
1 parent 95bc1a4 commit 43ac1ec
Showing 6 changed files with 27 additions and 14 deletions.
4 changes: 3 additions & 1 deletion laplace/baselaplace.py
@@ -2299,7 +2299,9 @@ def fit(

if self.likelihood == Likelihood.REGRESSION:
    b, C = f_batch.shape
    lambdas_batch = torch.unsqueeze(torch.eye(C), 0).repeat(b, 1, 1)
    lambdas_batch = torch.unsqueeze(
        torch.eye(C, device=self._device, dtype=self._dtype), 0
    ).repeat(b, 1, 1)
else:
    # second derivative of log lik is diag(p) - pp^T
    ps = torch.softmax(f_batch, dim=-1)
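
For context (not part of the commit): torch.eye without an explicit dtype uses the global default (typically float32), so the batched identity built here would clash with a model running in, say, float64. A minimal sketch of the before/after behaviour, using hypothetical shapes and f_batch's attributes in place of self._device and self._dtype:

import torch

f_batch = torch.randn(8, 3, dtype=torch.float64)  # hypothetical regression outputs
b, C = f_batch.shape

# old behaviour: identity defaults to float32 regardless of the model's dtype
lambdas_old = torch.unsqueeze(torch.eye(C), 0).repeat(b, 1, 1)

# new behaviour: identity inherits device and dtype
lambdas_new = torch.unsqueeze(
    torch.eye(C, device=f_batch.device, dtype=f_batch.dtype), 0
).repeat(b, 1, 1)

print(lambdas_old.dtype, lambdas_new.dtype)  # torch.float32 torch.float64
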
6 changes: 4 additions & 2 deletions laplace/curvature/curvature.py
@@ -153,8 +153,9 @@ def last_layer_jacobians(
output_size = int(f.numel() / bsize)

# calculate Jacobians using the feature vector 'phi'
p = next(self.model.parameters())
identity = (
    torch.eye(output_size, device=next(self.model.parameters()).device)
    torch.eye(output_size, device=p.device, dtype=p.dtype)
    .unsqueeze(0)
    .tile(bsize, 1, 1)
)
@@ -345,7 +346,8 @@ def _get_mc_functional_fisher(self, f: torch.Tensor) -> torch.Tensor:

for _ in range(self.num_samples):
    if self.likelihood == "regression":
        y_sample = f + torch.randn(f.shape, device=f.device)  # N(y | f, 1)
        # N(y | f, 1)
        y_sample = f + torch.randn(f.shape, device=f.device, dtype=f.dtype)
        grad_sample = f - y_sample  # functional MSE grad
    else:  # classification with softmax
        y_sample = torch.distributions.Multinomial(logits=f).sample()
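
A small illustration (not from the repo) of why the sampled noise needs the model's dtype: PyTorch type promotion would otherwise silently widen the Monte Carlo samples, e.g. half-precision f summed with float32 noise yields float32.

import torch

f = torch.randn(4, 2, dtype=torch.float16)  # hypothetical half-precision outputs

y_old = f + torch.randn(f.shape, device=f.device)                 # promoted to float32
y_new = f + torch.randn(f.shape, device=f.device, dtype=f.dtype)  # stays float16

print(y_old.dtype, y_new.dtype)  # torch.float32 torch.float16
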
5 changes: 3 additions & 2 deletions laplace/curvature/curvlinops.py
@@ -129,9 +129,10 @@ def full(
    check_deterministic=False,
    **curvlinops_kwargs,
)

p = next(self.model.parameters())
H = torch.as_tensor(
    linop @ torch.eye(linop.shape[0]),
    device=next(self.model.parameters()).device,
    linop @ torch.eye(linop.shape[0]), device=p.device, dtype=p.dtype
)

f = self.model(x)
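
The recurring idiom in this commit is to grab one parameter as a reference for both device and dtype. A sketch with a hypothetical model and a NumPy matrix standing in for the dense result of the curvlinops operator:

import numpy as np
import torch
from torch import nn

model = nn.Linear(5, 2).to(torch.float64)  # hypothetical double-precision model
p = next(model.parameters())

H_dense = np.eye(4)  # stand-in for `linop @ torch.eye(linop.shape[0])`
H = torch.as_tensor(H_dense, device=p.device, dtype=p.dtype)

print(H.device, H.dtype)  # cpu torch.float64
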
9 changes: 6 additions & 3 deletions laplace/marglik_training.py
@@ -159,7 +159,8 @@ def marglik_training(
optimizer_kwargs["weight_decay"] = 0.0

# get device, data set size N, number of layers H, number of parameters P
device = parameters_to_vector(model.parameters()).device
p = next(model.parameters())
device, dtype = p.device, p.dtype
N = len(train_loader.dataset)
trainable_params = [p for p in model.parameters() if p.requires_grad]
H = len(trainable_params)
@@ -170,7 +171,7 @@
# prior precision
log_prior_prec_init = np.log(temperature * prior_prec_init)
log_prior_prec = fix_prior_prec_structure(
    log_prior_prec_init, prior_structure, H, P, device
    log_prior_prec_init, prior_structure, H, P, device, dtype
)
log_prior_prec.requires_grad = True
hyperparameters.append(log_prior_prec)
@@ -182,7 +183,9 @@
elif likelihood == Likelihood.REGRESSION:
    criterion = MSELoss(reduction="mean")
    log_sigma_noise_init = np.log(sigma_noise_init)
    log_sigma_noise = log_sigma_noise_init * torch.ones(1, device=device)
    log_sigma_noise = log_sigma_noise_init * torch.ones(
        1, device=device, dtype=dtype
    )
    log_sigma_noise.requires_grad = True
    hyperparameters.append(log_sigma_noise)

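
A minimal sketch (hypothetical model and values, not the repo's training loop) of the pattern used above: read device and dtype from a parameter once, then create every hyperparameter tensor with them so the marginal-likelihood objective stays in a single precision.

import numpy as np
import torch
from torch import nn

model = nn.Linear(4, 1).to(torch.float64)  # hypothetical model
p = next(model.parameters())
device, dtype = p.device, p.dtype

sigma_noise_init = 1.0
log_sigma_noise = np.log(sigma_noise_init) * torch.ones(1, device=device, dtype=dtype)
log_sigma_noise.requires_grad = True

print(log_sigma_noise.dtype)  # torch.float64
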
8 changes: 5 additions & 3 deletions laplace/utils/utils.py
@@ -185,7 +185,9 @@ def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
-------
X_add_scalar : torch.Tensor
"""
indices = torch.LongTensor([[i, i] for i in range(X.shape[0])], device=X.device)
indices = torch.LongTensor(
    [[i, i] for i in range(X.shape[0])], device=X.device, dtype=X.dtype
)
values = X.new_ones(X.shape[0]).mul(value)
return X.index_put(tuple(indices.t()), values, accumulate=True)

@@ -278,10 +280,10 @@ def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor:
"""
trainable_params = [p for p in model.parameters() if p.requires_grad]
theta = parameters_to_vector(trainable_params)
device, P = theta.device, len(theta)
device, dtype, P = theta.device, theta.dtype, len(theta)
assert prior_prec.ndim == 1
if len(prior_prec) == 1: # scalar
return torch.ones(P, device=device) * prior_prec
return torch.ones(P, device=device, dtype=dtype) * prior_prec
elif len(prior_prec) == P: # full diagonal
return prior_prec.to(device)
else:
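
To illustrate the expand_prior_precision change (a sketch with made-up values, not the library function itself): a scalar prior precision is broadcast to a vector that should match the flattened parameters in both device and dtype; without the explicit dtype the expanded vector keeps the scalar's float32.

import torch

theta = torch.randn(10, dtype=torch.float64)  # hypothetical flattened parameters
prior_prec = torch.tensor([0.5])              # scalar prior precision, float32

expanded_old = torch.ones(len(theta), device=theta.device) * prior_prec
expanded_new = torch.ones(len(theta), device=theta.device, dtype=theta.dtype) * prior_prec

print(expanded_old.dtype, expanded_new.dtype)  # torch.float32 torch.float64
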
9 changes: 6 additions & 3 deletions tests/test_functional_laplace_unit.py
@@ -346,11 +346,14 @@ def test_dtype(laplace, backend, dtype):
    )
    la.fit(dataloader)

    assert la.L is not None
    assert la.L.dtype == dtype

    assert la.Sigma_inv is not None
    assert la.Sigma_inv.dtype == dtype

    # y_pred, y_var = la(X)
    # assert y_pred.dtype == dtype
    # assert y_var.dtype == dtype
    y_pred, y_var = la(X)
    assert y_pred.dtype == dtype
    assert y_var.dtype == dtype
except (ValueError, RuntimeError, SystemExit):
    pass

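The test above follows a common pattern for dtype checks; a rough, self-contained sketch of that pattern (the test name and model are hypothetical, not the repo's fixtures):

import pytest
import torch
from torch import nn

@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_outputs_follow_model_dtype(dtype):
    model = nn.Linear(3, 2).to(dtype)
    X = torch.randn(5, 3, dtype=dtype)

    y_pred = model(X)

    assert y_pred.dtype == dtype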