Skip to content

Commit

Permalink
Merge pull request #40 from chanind/refactor-train-sae
Browse files Browse the repository at this point in the history
Refactor SAE training and add unit tests
  • Loading branch information
jbloomAus authored Mar 22, 2024
2 parents bcb9a52 + 0acdcb3 commit 5aa0b11
Show file tree
Hide file tree
Showing 3 changed files with 714 additions and 250 deletions.
20 changes: 18 additions & 2 deletions sae_training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import gzip
import os
import pickle
from typing import Any
from typing import Any, NamedTuple

import einops
import torch
Expand All @@ -16,6 +16,15 @@
from sae_training.geometric_median import compute_geometric_median


# Structured result returned by SparseAutoencoder.forward(): the SAE
# reconstruction, the hidden feature activations, and each loss term
# (total, MSE, L1, and the ghost-gradient resid loss) as tensors.
ForwardOutput = NamedTuple(
    "ForwardOutput",
    [
        ("sae_out", torch.Tensor),
        ("feature_acts", torch.Tensor),
        ("loss", torch.Tensor),
        ("mse_loss", torch.Tensor),
        ("l1_loss", torch.Tensor),
        ("ghost_grad_loss", torch.Tensor),
    ],
)


class SparseAutoencoder(HookedRootModule):
""" """

Expand Down Expand Up @@ -138,7 +147,14 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
l1_loss = self.l1_coefficient * sparsity
loss = mse_loss + l1_loss + mse_loss_ghost_resid

return sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid
return ForwardOutput(
sae_out=sae_out,
feature_acts=feature_acts,
loss=loss,
mse_loss=mse_loss,
l1_loss=l1_loss,
ghost_grad_loss=mse_loss_ghost_resid,
)

@torch.no_grad()
def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
Expand Down
Loading

0 comments on commit 5aa0b11

Please sign in to comment.