chore: more flexible training losses #357

Open · wants to merge 6 commits into base: main
33 changes: 18 additions & 15 deletions sae_lens/training/sae_trainer.py
@@ -287,9 +287,6 @@
sae_in = output.sae_in
sae_out = output.sae_out
feature_acts = output.feature_acts
mse_loss = output.mse_loss
l1_loss = output.l1_loss
ghost_grad_loss = output.ghost_grad_loss
loss = output.loss.item()

# metrics for current acts
@@ -302,10 +299,6 @@

log_dict = {
# losses
"losses/mse_loss": mse_loss,
"losses/l1_loss": l1_loss
/ self.current_l1_coefficient, # normalize by l1 coefficient
"losses/auxiliary_reconstruction_loss": output.auxiliary_reconstruction_loss,
"losses/overall_loss": loss,
# variance explained
"metrics/explained_variance": explained_variance.mean().item(),
@@ -318,12 +311,16 @@
"details/current_l1_coefficient": self.current_l1_coefficient,
"details/n_training_tokens": n_training_tokens,
}
# Log ghost grad if we're using them
if self.cfg.use_ghost_grads:
if isinstance(ghost_grad_loss, torch.Tensor):
ghost_grad_loss = ghost_grad_loss.item()

log_dict["losses/ghost_grad_loss"] = ghost_grad_loss
for loss_name, loss_value in output.losses.items():
loss_item = _unwrap_item(loss_value)
# special case for l1 loss, which we normalize by the l1 coefficient
if loss_name == "l1_loss":
log_dict[f"losses/{loss_name}"] = (
loss_item / self.current_l1_coefficient
)
log_dict[f"losses/raw_{loss_name}"] = loss_item
else:
log_dict[f"losses/{loss_name}"] = loss_item

return log_dict

@@ -407,9 +404,11 @@
def _update_pbar(self, step_output: TrainStepOutput, pbar: tqdm, update_interval: int = 100): # type: ignore

if self.n_training_steps % update_interval == 0:
pbar.set_description(
f"{self.n_training_steps}| MSE Loss {step_output.mse_loss:.3f} | L1 {step_output.l1_loss:.3f}"
loss_strs = " | ".join(
f"{loss_name}: {_unwrap_item(loss_value):.5f}"
for loss_name, loss_value in step_output.losses.items()
)
pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
pbar.update(update_interval * self.cfg.train_batch_size_tokens)

def _begin_finetuning_if_needed(self):
@@ -430,3 +429,7 @@
param.requires_grad = False

self.finetuning = True


def _unwrap_item(item: float | torch.Tensor) -> float:
return item.item() if isinstance(item, torch.Tensor) else item
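For context on the logging change above: _build_train_step_log_dict no longer reads fixed loss fields but loops over output.losses, unwrapping tensors with the new _unwrap_item helper and normalizing l1_loss by the current coefficient. A minimal standalone sketch of that loop (build_loss_log_dict is a hypothetical wrapper written here for illustration; it is not part of this PR):

import torch


def _unwrap_item(item: float | torch.Tensor) -> float:
    # tensors are converted to plain Python floats for logging
    return item.item() if isinstance(item, torch.Tensor) else item


def build_loss_log_dict(
    losses: dict[str, float | torch.Tensor], l1_coefficient: float
) -> dict[str, float]:
    # mirrors the per-loss loop added to _build_train_step_log_dict
    log_dict: dict[str, float] = {}
    for loss_name, loss_value in losses.items():
        loss_item = _unwrap_item(loss_value)
        if loss_name == "l1_loss":
            # l1 loss is logged normalized by the coefficient, plus a raw copy
            log_dict[f"losses/{loss_name}"] = loss_item / l1_coefficient
            log_dict[f"losses/raw_{loss_name}"] = loss_item
        else:
            log_dict[f"losses/{loss_name}"] = loss_item
    return log_dict


# example with values similar to those used in test_build_train_step_log_dict below
print(build_loss_log_dict({"mse_loss": 0.25, "l1_loss": torch.tensor(0.1)}, l1_coefficient=2.0))

Because the loop logs every key it finds, any SAE variant that adds a new entry to output.losses gets it logged without further trainer changes.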
35 changes: 14 additions & 21 deletions sae_lens/training/training_sae.py
@@ -93,10 +93,7 @@ class TrainStepOutput:
sae_out: torch.Tensor
feature_acts: torch.Tensor
loss: torch.Tensor # we need to call backwards on this
mse_loss: float
l1_loss: float
ghost_grad_loss: float
auxiliary_reconstruction_loss: float = 0.0
losses: dict[str, float | torch.Tensor]


@dataclass(kw_only=True)
@@ -371,9 +368,7 @@ def training_forward_pass(
per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
mse_loss = per_item_mse_loss.sum(dim=-1).mean()

l1_loss = torch.tensor(0.0, device=sae_in.device)
aux_reconstruction_loss = torch.tensor(0.0, device=sae_in.device)
ghost_grad_loss = torch.tensor(0.0, device=sae_in.device)
losses: dict[str, float | torch.Tensor] = {}

if self.cfg.architecture == "gated":
# Gated SAE Loss Calculation
@@ -396,13 +391,15 @@
aux_reconstruction_loss = torch.sum(
(via_gate_reconstruction - sae_in) ** 2, dim=-1
).mean()

loss = mse_loss + l1_loss + aux_reconstruction_loss
losses["auxiliary_reconstruction_loss"] = aux_reconstruction_loss
losses["l1_loss"] = l1_loss
elif self.cfg.architecture == "jumprelu":
threshold = torch.exp(self.log_threshold)
l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1) # type: ignore
l1_loss = (current_l1_coefficient * l0).mean()
loss = mse_loss + l1_loss
l0_loss = (current_l1_coefficient * l0).mean()
loss = mse_loss + l0_loss
losses["l0_loss"] = l0_loss
else:
# default SAE sparsity loss
weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
@@ -411,8 +408,7 @@
) # sum over the feature dimension

l1_loss = (current_l1_coefficient * sparsity).mean()
loss = mse_loss + l1_loss + ghost_grad_loss

loss = mse_loss + l1_loss
if (
self.cfg.use_ghost_grads
and self.training
@@ -425,21 +421,18 @@
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)
loss = loss + ghost_grad_loss
losses["ghost_grad_loss"] = ghost_grad_loss
loss = loss + ghost_grad_loss
losses["l1_loss"] = l1_loss

losses["mse_loss"] = mse_loss

return TrainStepOutput(
sae_in=sae_in,
sae_out=sae_out,
feature_acts=feature_acts,
loss=loss,
mse_loss=mse_loss.item(),
l1_loss=l1_loss.item(),
ghost_grad_loss=(
ghost_grad_loss.item()
if isinstance(ghost_grad_loss, torch.Tensor)
else ghost_grad_loss
),
auxiliary_reconstruction_loss=aux_reconstruction_loss.item(),
losses=losses,
)

def calculate_ghost_grad_loss(
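Summing up the training_sae.py change: TrainStepOutput now carries a losses dict whose entries (mse_loss plus the architecture-specific sparsity and auxiliary terms) add up to the scalar loss used for backprop, and optional terms such as ghost_grad_loss appear only when they are actually computed. A rough consumer-side sketch under those assumptions (FakeTrainStepOutput is a stand-in defined here for illustration, not the real dataclass):

import torch
from dataclasses import dataclass, field


@dataclass(kw_only=True)
class FakeTrainStepOutput:
    # illustration-only stand-in mirroring the new fields on TrainStepOutput
    loss: torch.Tensor
    losses: dict[str, float | torch.Tensor] = field(default_factory=dict)


mse_loss = torch.tensor(0.25)
l1_loss = torch.tensor(0.10)
output = FakeTrainStepOutput(
    loss=mse_loss + l1_loss,
    losses={"mse_loss": mse_loss, "l1_loss": l1_loss},
)

# the logged components should sum back to the overall training loss
total = sum(torch.as_tensor(v) for v in output.losses.values())
assert torch.allclose(total, output.loss)

# optional losses may be missing, so read them with a default
ghost_grad_loss = output.losses.get("ghost_grad_loss", 0.0)
print(total.item(), ghost_grad_loss)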
11 changes: 8 additions & 3 deletions tests/unit/test_jumprelu_sae.py
@@ -45,8 +45,11 @@ def test_jumprelu_sae_training_forward_pass():

assert train_step_output.sae_out.shape == (batch_size, d_in)
assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
assert pytest.approx(train_step_output.loss.detach(), rel=1e-3) == (
train_step_output.mse_loss + train_step_output.l1_loss
assert (
pytest.approx(train_step_output.loss.detach(), rel=1e-3)
== (
train_step_output.losses["mse_loss"] + train_step_output.losses["l0_loss"]
).item() # type: ignore
)

expected_mse_loss = (
@@ -57,4 +60,6 @@
.float()
)

assert pytest.approx(train_step_output.mse_loss) == expected_mse_loss
assert (
pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss # type: ignore
)
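One behavioural detail worth noting in the jumprelu test above: the sparsity term is now recorded under the key "l0_loss" rather than "l1_loss", since it penalizes the coefficient-weighted L0 count of active features. A forward-pass-only sketch of that term with made-up values; in the real code Step.apply is a straight-through estimator, which for the forward computation reduces to a hard threshold:

import torch

hidden_pre = torch.randn(4, 8)      # pre-activations: batch of 4, d_sae = 8
threshold = torch.full((8,), 0.1)   # per-feature jumprelu thresholds
l1_coefficient = 2.0                # sparsity coefficient from the config

# forward value of Step(hidden_pre, threshold, bandwidth): 1 where the unit fires
l0 = (hidden_pre > threshold).float().sum(dim=-1)
l0_loss = (l1_coefficient * l0).mean()
print(l0_loss)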
15 changes: 11 additions & 4 deletions tests/unit/training/test_gated_sae.py
@@ -88,7 +88,9 @@ def test_gated_sae_loss():
).mean()

expected_loss = (
train_step_output.mse_loss + preactivation_l1_loss + aux_reconstruction_loss
train_step_output.losses["mse_loss"]
+ preactivation_l1_loss
+ aux_reconstruction_loss
)
assert (
pytest.approx(train_step_output.loss.item(), rel=1e-3) == expected_loss.item()
@@ -129,9 +131,14 @@ def test_gated_sae_training_forward_pass():
# Detach the loss tensor and convert to numpy for comparison
detached_loss = train_step_output.loss.detach().cpu().numpy()
expected_loss = (
train_step_output.mse_loss
+ train_step_output.l1_loss
+ train_step_output.auxiliary_reconstruction_loss
(
train_step_output.losses["mse_loss"]
+ train_step_output.losses["l1_loss"]
+ train_step_output.losses["auxiliary_reconstruction_loss"]
)
.detach() # type: ignore
.cpu()
.numpy()
)

assert pytest.approx(detached_loss, rel=1e-3) == expected_loss
15 changes: 9 additions & 6 deletions tests/unit/training/test_sae_trainer.py
@@ -112,7 +112,7 @@ def test_train_step__output_looks_reasonable(trainer: SAETrainer) -> None:
assert output.sae_out.shape == output.sae_in.shape
assert output.feature_acts.shape == (4, 128) # batch_size, d_sae
# ghost grads shouldn't trigger until dead_feature_window, which hasn't been reached yet
assert output.ghost_grad_loss == 0
assert output.losses.get("ghost_grad_loss", 0) == 0
assert trainer.n_frac_active_tokens == 4
assert trainer.act_freq_scores.sum() > 0 # at least SOME acts should have fired
assert torch.allclose(
@@ -169,9 +169,11 @@ def test_build_train_step_log_dict(trainer: SAETrainer) -> None:
sae_out=torch.tensor([[0, 0], [0, 2], [0.5, 1]]).float(),
feature_acts=torch.tensor([[0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 1, 1]]).float(),
loss=torch.tensor(0.5),
mse_loss=0.25,
l1_loss=0.1,
ghost_grad_loss=0.15,
losses={
"mse_loss": 0.25,
"l1_loss": 0.1,
"ghost_grad_loss": 0.15,
},
)

# we're relying on the trainer only for some of the metrics here
@@ -183,9 +185,10 @@
assert log_dict == {
"losses/mse_loss": 0.25,
# l1 loss is scaled by l1_coefficient
"losses/l1_loss": train_output.l1_loss / trainer.cfg.l1_coefficient,
"losses/auxiliary_reconstruction_loss": 0.0,
"losses/l1_loss": train_output.losses["l1_loss"] / trainer.cfg.l1_coefficient,
"losses/raw_l1_loss": train_output.losses["l1_loss"],
"losses/overall_loss": 0.5,
"losses/ghost_grad_loss": 0.15,
"metrics/explained_variance": 0.75,
"metrics/explained_variance_std": 0.25,
"metrics/l0": 2.0,
41 changes: 27 additions & 14 deletions tests/unit/training/test_sae_training.py
@@ -154,10 +154,16 @@ def test_sae_forward(training_sae: TrainingSAE):

assert train_step_output.sae_out.shape == (batch_size, d_in)
assert train_step_output.feature_acts.shape == (batch_size, d_sae)
assert pytest.approx(train_step_output.loss.detach(), rel=1e-3) == (
train_step_output.mse_loss
+ train_step_output.l1_loss
+ train_step_output.ghost_grad_loss
assert (
pytest.approx(train_step_output.loss.detach(), rel=1e-3)
== (
train_step_output.losses["mse_loss"]
+ train_step_output.losses["l1_loss"]
+ train_step_output.losses.get("ghost_grad_loss", 0.0)
)
.detach() # type: ignore
.cpu()
.numpy()
)

expected_mse_loss = (
@@ -168,7 +174,9 @@
.float()
)

assert pytest.approx(train_step_output.mse_loss) == expected_mse_loss
assert (
pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss # type: ignore
)

if not training_sae.cfg.scale_sparsity_penalty_by_decoder_norm:
expected_l1_loss = train_step_output.feature_acts.sum(dim=1).mean(dim=(0,))
@@ -179,7 +187,7 @@
.mean()
)
assert (
pytest.approx(train_step_output.l1_loss, rel=1e-3)
pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore
== training_sae.cfg.l1_coefficient * expected_l1_loss.detach().float()
)

@@ -203,7 +211,7 @@ def test_sae_forward_with_mse_loss_norm(

assert train_step_output.sae_out.shape == (batch_size, d_in)
assert train_step_output.feature_acts.shape == (batch_size, d_sae)
assert train_step_output.ghost_grad_loss == 0.0
assert "ghost_grad_loss" not in train_step_output.losses

x_centred = x - x.mean(dim=0, keepdim=True)
expected_mse_loss = (
@@ -217,12 +225,17 @@
.item()
)

assert pytest.approx(train_step_output.mse_loss) == expected_mse_loss
assert pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss # type: ignore

assert pytest.approx(train_step_output.loss.detach(), rel=1e-3) == (
train_step_output.mse_loss
+ train_step_output.l1_loss
+ train_step_output.ghost_grad_loss
assert (
pytest.approx(train_step_output.loss.detach(), rel=1e-3)
== (
train_step_output.losses["mse_loss"]
+ train_step_output.losses["l1_loss"]
+ train_step_output.losses.get("ghost_grad_loss", 0.0)
)
.detach() # type: ignore
.numpy()
)

if not training_sae.cfg.scale_sparsity_penalty_by_decoder_norm:
Expand All @@ -234,7 +247,7 @@ def test_sae_forward_with_mse_loss_norm(
.mean()
)
assert (
pytest.approx(train_step_output.l1_loss, rel=1e-3)
pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore
== training_sae.cfg.l1_coefficient * expected_l1_loss.detach().float()
)

Expand All @@ -255,7 +268,7 @@ def test_SparseAutoencoder_forward_ghost_grad_loss_non_zero(
).bool(), # all neurons are dead.
)

assert train_step_output.ghost_grad_loss != 0.0
assert train_step_output.losses["ghost_grad_loss"] != 0.0


def test_calculate_ghost_grad_loss(
Expand Down