Support for multidim outputs #266

Open · wants to merge 7 commits into base: main
Changes from 1 commit
Address weight-sharing dims in ASDL Jacobian
wiseodd committed Nov 20, 2024
commit 6c310f1895b4105711caf8476f50c97e56e9736d
9 changes: 5 additions & 4 deletions examples/lm_example.py
@@ -10,6 +10,7 @@
 BATCH_SIZE = 4  # B
 SEQ_LENGTH = 6  # L
 EMBED_DIM = 8  # D
+OUTPUT_SIZE = 2  # K
 INPUT_KEY = "input"
 OUTPUT_KEY = "output"
 
@@ -18,18 +19,18 @@ class Model(nn.Module):
     def __init__(self):
         super().__init__()
         self.attn = nn.MultiheadAttention(EMBED_DIM, num_heads=1)
-        self.final_layer = nn.Linear(EMBED_DIM, 1)
+        self.final_layer = nn.Linear(EMBED_DIM, OUTPUT_SIZE)
 
     def forward(self, x):
         x = x[INPUT_KEY].view(-1, SEQ_LENGTH, EMBED_DIM)  # (B, L, D)
         out = self.attn(x, x, x, need_weights=False)[0]  # (B, L, D)
-        return self.final_layer(out)  # (B, L, 1)
+        return self.final_layer(out)  # (B, L, K)
 
 
 ds = TensorDict(
     {
         INPUT_KEY: torch.randn((100, SEQ_LENGTH, EMBED_DIM)),
-        OUTPUT_KEY: torch.randn((100, SEQ_LENGTH, 1)),
+        OUTPUT_KEY: torch.randn((100, SEQ_LENGTH, OUTPUT_SIZE)),
     },
     batch_size=[100],
 )  # simulates a dataset
@@ -54,7 +55,7 @@ def forward(self, x):
     backend=AsdlEF,
     dict_key_x=INPUT_KEY,
     dict_key_y=OUTPUT_KEY,
-    enable_backprop=True,
+    enable_backprop=False,  # True => functorch Jacobian, False => ASDL Jacobian
 )
 la.fit(dl)
 
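For illustration only (not part of this commit): a minimal shape check of the updated example, assuming the constants and the Model class exactly as in the diff above, with a plain dict standing in for the TensorDict batch.

import torch
import torch.nn as nn

BATCH_SIZE = 4   # B
SEQ_LENGTH = 6   # L
EMBED_DIM = 8    # D
OUTPUT_SIZE = 2  # K
INPUT_KEY = "input"


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(EMBED_DIM, num_heads=1)
        self.final_layer = nn.Linear(EMBED_DIM, OUTPUT_SIZE)

    def forward(self, x):
        x = x[INPUT_KEY].view(-1, SEQ_LENGTH, EMBED_DIM)  # (B, L, D)
        out = self.attn(x, x, x, need_weights=False)[0]  # (B, L, D)
        return self.final_layer(out)  # (B, L, K)


model = Model()
batch = {INPUT_KEY: torch.randn(BATCH_SIZE, SEQ_LENGTH, EMBED_DIM)}
print(model(batch).shape)  # torch.Size([4, 6, 2]), i.e. (B, L, K) instead of the old (B, L, 1)

With enable_backprop=False the example exercises the ASDL Jacobian path in laplace/curvature/asdl.py below, which is where the extra weight-sharing (sequence) dimension has to be handled.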
33 changes: 19 additions & 14 deletions laplace/curvature/asdl.py
@@ -2,7 +2,7 @@
 
 import warnings
 from collections.abc import MutableMapping
-from typing import Any
+from typing import Any, Sequence
 
 import torch
 from asdl.fisher import FisherConfig, get_fisher_maker
@@ -83,16 +83,23 @@ def closure():
                 )
                 return f
 
-            Ji, f = batch_gradient(
-                self.model,
-                closure,
-                return_outputs=True,
-                batch_size=self._get_batch_size(x),
-            )
+            Ji, f = batch_gradient(self.model, closure, return_outputs=True)
+
             if self.subnetwork_indices is not None:
                 Ji = Ji[:, self.subnetwork_indices]
+
             Js.append(Ji)
-        Js = torch.stack(Js, dim=1)
+
+        Js = torch.stack(Js, dim=1)  # (prod_batch_dim, n_outputs, n_params)
+
+        if Js.ndim != 3:
+            raise AttributeError("ASDL Jacobian must be a 3-dim tensor.")
+
+        if Js.shape[0] != self._get_batch_size(x)[0]:
+            # There are intermediate "weight-sharing" dimensions.
+            # So, reshape `prod_batch_dim` into the original, individual batch dims.
+            Js = Js.reshape(*self._get_batch_size(x), *Js.shape[-2:])
+
         return Js, f
 
     def gradients(
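For illustration only (not part of this commit): the new reshape branch above, replayed with made-up sizes. Js is stacked as (prod_batch_dim, n_outputs, n_params); when the model output carries a weight-sharing dimension, prod_batch_dim is B*L rather than B, and the branch folds it back into the original batch dims.

import torch

B, L, K, P = 4, 6, 2, 10  # hypothetical batch, sequence, output, and parameter sizes
batch_dims = (B, L)  # what _get_batch_size(x) returns for an input of shape (B, L, D)

# Per-output gradients come back with the leading dims flattened to B*L.
Js = torch.stack([torch.randn(B * L, P) for _ in range(K)], dim=1)  # (B*L, K, P)

if Js.shape[0] != batch_dims[0]:  # 24 != 4, so there are weight-sharing dims
    Js = Js.reshape(*batch_dims, *Js.shape[-2:])

print(Js.shape)  # torch.Size([4, 6, 2, 10]), i.e. (B, L, K, P)

For a plain (B, K) output there is no extra dimension, Js.shape[0] already equals B, and the Jacobian stays in the usual (B, K, P) layout.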
@@ -120,9 +127,7 @@ def closure():
             loss.backward()
             return loss
 
-        Gs, loss = batch_gradient(
-            self.model, closure, return_outputs=True, batch_size=self._get_batch_size(x)
-        )
+        Gs, loss = batch_gradient(self.model, closure, return_outputs=True)
         if self.subnetwork_indices is not None:
             Gs = Gs[:, self.subnetwork_indices]
         return Gs, loss
@@ -253,16 +258,16 @@ def kron(
     def _get_batch_size(
         self,
         x: torch.Tensor | MutableMapping[str, torch.Tensor | Any],
-    ) -> int | None:
+    ) -> Sequence[int]:
         """
         ASDL assumes that all leading dimensions are the batch size by default (batch_size = None).
         Here, we want to specify that only the first dimension is the actual batch size.
         This is the case for LLMs.
        """
         if isinstance(x, MutableMapping):
-            return x[self.dict_key_x].shape[0]
+            return x[self.dict_key_x].shape[:-1]
         else:
-            return None  # Use ASDL default behavior
+            return x.shape[:-1]
 
 
 class AsdlHessian(AsdlInterface):
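For illustration only (not part of this commit): with the Sequence[int] return type, _get_batch_size treats every dimension except the trailing feature dimension as a batch dimension, and those are exactly the dims the jacobians method above uses to un-flatten the stacked Jacobian. A sketch with the example's input shape:

import torch

B, L, D = 4, 6, 8
x_tensor = torch.randn(B, L, D)
x_dict = {"input": x_tensor}  # dict_key_x == "input" in examples/lm_example.py

# Mimicking the two branches of the updated _get_batch_size:
print(x_dict["input"].shape[:-1])  # torch.Size([4, 6]) for the MutableMapping branch
print(x_tensor.shape[:-1])         # torch.Size([4, 6]) for the plain-tensor branch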