Commit f8fb3ef: Merge branch 'main' into fix_np

hijohnnylin committed Apr 15, 2024 (2 parents: e87788d + 2b4baed)
Showing 14 changed files with 715 additions and 18 deletions.
28 changes: 28 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,34 @@



## v0.3.0 (2024-04-15)

### Feature

* feat: add basic tutorial for training saes ([`1847280`](https://github.com/jbloomAus/SAELens/commit/18472800481dbe584e5fab8533ac47a1ee39a062))


## v0.2.2 (2024-04-15)

### Fix

* fix: dense batch dim mse norm optional ([`8018bc9`](https://github.com/jbloomAus/SAELens/commit/8018bc939811bdb7e59c999e055c26401af6d0d2))

### Unknown

* format ([`c359c27`](https://github.com/jbloomAus/SAELens/commit/c359c272ae4d5b1e25da5333c4beff99e924532c))

* make dense_batch_mse_normalization optional ([`c41774e`](https://github.com/jbloomAus/SAELens/commit/c41774e5cfaeb195e3320e9e3fc93d60d921337d))

* add warning in run script ([`9a772ca`](https://github.com/jbloomAus/SAELens/commit/9a772ca6da155b5e97bc3109da74457f5addfbfd))

* update sae loading code ([`356a8ef`](https://github.com/jbloomAus/SAELens/commit/356a8efba06e4f453d2f15afe9171b71d780819a))

* add device override to session loader ([`96b1e12`](https://github.com/jbloomAus/SAELens/commit/96b1e120d78f5f05cd94aec7a763bc14849aa1d3))

* update readme ([`5cd5652`](https://github.com/jbloomAus/SAELens/commit/5cd5652a4b19ba985d20c229b6a92d17774bc6b9))


## v0.2.1 (2024-04-13)

### Fix
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sae-lens"
version = "0.2.1"
version = "0.3.0"
description = "Training and Analyzing Sparse Autoencoders (SAEs)"
authors = ["Joseph Bloom"]
readme = "README.md"
2 changes: 1 addition & 1 deletion sae_lens/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.2.1"
__version__ = "0.3.0"

from .training.activations_store import ActivationsStore
from .training.cache_activations_runner import cache_activations_runner
2 changes: 1 addition & 1 deletion sae_lens/analysis/dashboard_runner.py
@@ -11,6 +11,7 @@
import plotly
import plotly.express as px
import torch
+import wandb
from sae_vis.data_config_classes import (
ActsHistogramConfig,
Column,
@@ -23,7 +24,6 @@
from torch.nn.functional import cosine_similarity
from tqdm import tqdm

-import wandb
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader


2 changes: 1 addition & 1 deletion sae_lens/training/config.py
@@ -2,7 +2,6 @@
 from typing import Any, Optional, cast

 import torch
-
 import wandb

DTYPE_MAP = {
@@ -54,6 +53,7 @@ class LanguageModelSAERunnerConfig:
d_sae: Optional[int] = None

# Training Parameters
+    mse_loss_normalization: Optional[str] = None
l1_coefficient: float | list[float] = 1e-3
lp_norm: float | list[float] = 1
lr: float | list[float] = 3e-4
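For reference, enabling the new option from user code might look like this minimal sketch; it assumes the remaining LanguageModelSAERunnerConfig fields have workable defaults, which a real training run would need to override.

# Hedged sketch: turn on the new optional MSE normalization.
# Only fields visible in this diff are set; all other fields are
# assumed to have usable defaults.
from sae_lens.training.config import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    mse_loss_normalization="dense_batch",  # None (the default) keeps plain MSE
    l1_coefficient=1e-3,
    lr=3e-4,
)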
2 changes: 1 addition & 1 deletion sae_lens/training/evals.py
@@ -3,10 +3,10 @@

import pandas as pd
import torch
+import wandb
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

-import wandb
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.sparse_autoencoder import SparseAutoencoder

1 change: 1 addition & 0 deletions sae_lens/training/lm_runner.py
@@ -1,6 +1,7 @@
from typing import Any, cast

+import wandb

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

23 changes: 15 additions & 8 deletions sae_lens/training/sparse_autoencoder.py
@@ -129,7 +129,9 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
)

# add config for whether l2 is normalized:
-        per_item_mse_loss = _per_item_mse_loss_with_target_norm(sae_out, x)
+        per_item_mse_loss = _per_item_mse_loss_with_target_norm(
+            sae_out, x, self.cfg.mse_loss_normalization
+        )
ghost_grad_loss = torch.tensor(0.0, dtype=self.dtype, device=self.device)
# gate on config and training so evals is not slowed down.
if (
@@ -354,7 +356,7 @@ def calculate_ghost_grad_loss(

# 3.
per_item_mse_loss_ghost_resid = _per_item_mse_loss_with_target_norm(
-            ghost_out, residual.detach()
+            ghost_out, residual.detach(), self.cfg.mse_loss_normalization
)
mse_rescaling_factor = (
per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
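For context, the rescaling above matches the ghost-gradient term to the scale of the main reconstruction loss. A rough sketch of the step follows; the per-item tensor shapes, the .detach(), and the final reduction are assumptions, since this hunk is truncated.

import torch

# Assumed shape: one loss value per batch item.
per_item_mse_loss = torch.rand(32, 1)
per_item_mse_loss_ghost_resid = torch.rand(32, 1)

# Scale the ghost term to the magnitude of the main loss; 1e-6 guards
# against division by zero. Detaching keeps the scale factor out of
# the gradient (an assumption here).
mse_rescaling_factor = (
    per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
).detach()
ghost_grad_loss = (mse_rescaling_factor * per_item_mse_loss_ghost_resid).mean()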
@@ -367,15 +369,20 @@


def _per_item_mse_loss_with_target_norm(
-    preds: torch.Tensor, target: torch.Tensor
+    preds: torch.Tensor,
+    target: torch.Tensor,
+    mse_loss_normalization: Optional[str] = None,
) -> torch.Tensor:
"""
Calculate MSE loss per item in the batch, without taking a mean.
Then, normalizes by the L2 norm of the centered target.
This normalization seems to improve performance.
"""
-    target_centered = target - target.mean(dim=0, keepdim=True)
-    normalization = target_centered.norm(dim=-1, keepdim=True)
-    return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
-        normalization + 1e-6
-    )
+    if mse_loss_normalization == "dense_batch":
+        target_centered = target - target.mean(dim=0, keepdim=True)
+        normalization = target_centered.norm(dim=-1, keepdim=True)
+        return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
+            normalization + 1e-6
+        )
+    else:
+        return torch.nn.functional.mse_loss(preds, target, reduction="none")
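To make the new branch concrete, here is a minimal self-contained sketch of the two behaviors; it mirrors the function above rather than importing it, so the names are illustrative.

import torch

def per_item_mse(preds, target, mse_loss_normalization=None):
    # Elementwise squared error, no reduction.
    loss = torch.nn.functional.mse_loss(preds, target, reduction="none")
    if mse_loss_normalization == "dense_batch":
        # Centre targets over the batch, then divide each item's loss
        # by the L2 norm of its centred target.
        target_centered = target - target.mean(dim=0, keepdim=True)
        loss = loss / (target_centered.norm(dim=-1, keepdim=True) + 1e-6)
    return loss

preds, target = torch.randn(32, 64), torch.randn(32, 64)
plain = per_item_mse(preds, target)                  # default: plain MSE
normed = per_item_mse(preds, target, "dense_batch")  # normalized variant
assert plain.shape == normed.shape == (32, 64)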
2 changes: 1 addition & 1 deletion sae_lens/training/toy_model_runner.py
@@ -3,8 +3,8 @@

 import einops
 import torch
-
 import wandb
+
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
from sae_lens.training.toy_models import Config as ToyConfig
from sae_lens.training.toy_models import Model as ToyModel
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_language_model.py
@@ -3,13 +3,13 @@
from typing import Any, cast

import torch
+import wandb
from safetensors.torch import save_file
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm
from transformer_lens import HookedTransformer

-import wandb
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.evals import run_evals
from sae_lens.training.geometric_median import compute_geometric_median
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_toy_model.py
@@ -1,10 +1,10 @@
from typing import Any, cast

import torch
+import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

-import wandb
from sae_lens.training.sparse_autoencoder import SparseAutoencoder


45 changes: 44 additions & 1 deletion tests/unit/training/test_sparse_autoencoder.py
@@ -136,6 +136,47 @@ def test_sparse_autoencoder_forward(sparse_autoencoder: SparseAutoencoder):
assert l1_loss.shape == ()
assert torch.allclose(loss, mse_loss + l1_loss)

+    expected_mse_loss = (torch.pow((sae_out - x.float()), 2)).mean()
+
+    assert torch.allclose(mse_loss, expected_mse_loss)
+    expected_l1_loss = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
+    assert torch.allclose(l1_loss, sparse_autoencoder.l1_coefficient * expected_l1_loss)
+
+    # check everything has the right dtype
+    assert sae_out.dtype == sparse_autoencoder.dtype
+    assert feature_acts.dtype == sparse_autoencoder.dtype
+    assert loss.dtype == sparse_autoencoder.dtype
+    assert mse_loss.dtype == sparse_autoencoder.dtype
+    assert l1_loss.dtype == sparse_autoencoder.dtype
+
+
+def test_sparse_autoencoder_forward_with_mse_loss_norm(
+    sparse_autoencoder: SparseAutoencoder,
+):
+    batch_size = 32
+    d_in = sparse_autoencoder.d_in
+    d_sae = sparse_autoencoder.d_sae
+    sparse_autoencoder.cfg.mse_loss_normalization = "dense_batch"
+
+    x = torch.randn(batch_size, d_in)
+    (
+        sae_out,
+        feature_acts,
+        loss,
+        mse_loss,
+        l1_loss,
+        _ghost_grad_loss,
+    ) = sparse_autoencoder.forward(
+        x,
+    )
+
+    assert sae_out.shape == (batch_size, d_in)
+    assert feature_acts.shape == (batch_size, d_sae)
+    assert loss.shape == ()
+    assert mse_loss.shape == ()
+    assert l1_loss.shape == ()
+    assert torch.allclose(loss, mse_loss + l1_loss)
+
x_centred = x - x.mean(dim=0, keepdim=True)
expected_mse_loss = (
torch.pow((sae_out - x.float()), 2)
@@ -199,7 +240,9 @@ def test_per_item_mse_loss_with_norm_matches_original_implementation() -> None:
torch.pow((input - target.float()), 2)
/ (target_centered**2).sum(dim=-1, keepdim=True).sqrt()
)
-    sae_res = _per_item_mse_loss_with_target_norm(input, target)
+    sae_res = _per_item_mse_loss_with_target_norm(
+        input, target, mse_loss_normalization="dense_batch"
+    )
assert torch.allclose(orig_impl_res, sae_res, atol=1e-5)


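The new test can be exercised on its own with a standard pytest node id; a programmatic sketch:

# Run just the test added above; the node id is taken from this diff.
import pytest

pytest.main([
    "tests/unit/training/test_sparse_autoencoder.py::test_sparse_autoencoder_forward_with_mse_loss_norm",
])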
2 changes: 1 addition & 1 deletion tests/unit/training/test_train_sae_on_language_model.py
@@ -5,11 +5,11 @@

import pytest
import torch
+import wandb
from datasets import Dataset
from torch import Tensor
from transformer_lens import HookedTransformer

-import wandb
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.optim import get_scheduler
from sae_lens.training.sae_group import SparseAutoencoderDictionary