Commit 6658392

Merge branch 'main' into fix_np

hijohnnylin committed Apr 15, 2024
2 parents f8fb3ef + feca408
Showing 4 changed files with 5,225 additions and 36 deletions.
sae_lens/training/config.py (13 changes: 7 additions & 6 deletions)

@@ -2,6 +2,7 @@
 from typing import Any, Optional, cast
 
 import torch
+
 import wandb
 
 DTYPE_MAP = {
@@ -33,6 +34,12 @@ class LanguageModelSAERunnerConfig:
 
     # SAE Parameters
     d_in: int = 512
+    d_sae: Optional[int] = None
+    b_dec_init_method: str = "geometric_median"
+    expansion_factor: int | list[int] = 4
+    from_pretrained_path: Optional[str] = None
+    apply_b_dec_to_input: bool = True
+    decoder_orthogonal_init: bool = True
 
     # Activation Store Parameters
     n_batches_in_buffer: int = 20
@@ -46,12 +53,6 @@
     dtype: str | torch.dtype = "float32"  # type: ignore #
     prepend_bos: bool = True
 
-    # SAE Parameters
-    b_dec_init_method: str = "geometric_median"
-    expansion_factor: int | list[int] = 4
-    from_pretrained_path: Optional[str] = None
-    d_sae: Optional[int] = None
-
     # Training Parameters
     mse_loss_normalization: Optional[str] = None
     l1_coefficient: float | list[float] = 1e-3
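Note: the diff above consolidates the two previously separate "SAE Parameters" blocks and introduces two new flags, apply_b_dec_to_input and decoder_orthogonal_init (their effects appear in the sparse_autoencoder.py diff below). A minimal construction sketch, assuming the remaining dataclass fields all carry usable defaults and that d_sae, when left as None, is derived from d_in * expansion_factor (a common SAE convention, not shown in this diff):

from sae_lens.training.config import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    d_in=512,
    expansion_factor=4,              # d_sae left as None, presumably derived from this
    b_dec_init_method="geometric_median",
    apply_b_dec_to_input=True,       # subtract b_dec from the encoder input
    decoder_orthogonal_init=True,    # orthogonal initialization of W_dec
)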
sae_lens/training/session_loader.py (16 changes: 7 additions & 9 deletions)

@@ -5,7 +5,6 @@
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.config import LanguageModelSAERunnerConfig
 from sae_lens.training.sae_group import SparseAutoencoderDictionary
-from sae_lens.training.sparse_autoencoder import SparseAutoencoder
 
 
 class LMSparseAutoencoderSessionloader:
@@ -46,17 +45,16 @@ def load_pretrained_sae(
         """
 
         # load the SAE
-        sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)
-        sparse_autoencoder.to(device)
-        sparse_autoencoder.cfg.device = device
+        sparse_autoencoders = SparseAutoencoderDictionary.load_from_pretrained(
+            path, device
+        )
+        first_sparse_autoencoder_cfg = next(iter(sparse_autoencoders))[1].cfg
 
         # load the model, SAE and activations loader with it.
-        session_loader = cls(sparse_autoencoder.cfg)
-        model, sae_group, activations_loader = (
-            session_loader.load_sae_training_group_session()
-        )
+        session_loader = cls(first_sparse_autoencoder_cfg)
+        model, _, activations_loader = session_loader.load_sae_training_group_session()
 
-        return model, sae_group, activations_loader
+        return model, sparse_autoencoders, activations_loader
 
     def get_model(self, model_name: str) -> HookedTransformer:
         """
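After this change, load_pretrained_sae returns the whole SparseAutoencoderDictionary in place of a single SparseAutoencoder, and the session config is taken from the first SAE in the dictionary. A hedged call-site sketch (the path and device values are placeholders; the (name, sae) iteration shape is inferred from the next(iter(...))[1].cfg line in the diff):

from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

model, sae_dictionary, activations_loader = (
    LMSparseAutoencoderSessionloader.load_pretrained_sae(
        "path/to/pretrained_sae", "cuda"
    )
)

# Iterating the dictionary yields (name, sae) pairs, which is why the diff
# pulls the config via next(iter(sparse_autoencoders))[1].cfg.
for name, sae in sae_dictionary:
    print(name, sae.cfg.d_sae)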
sae_lens/training/sparse_autoencoder.py (5 changes: 4 additions & 1 deletion)

@@ -87,6 +87,9 @@ def __init__(
             )
         )
 
+        if self.cfg.decoder_orthogonal_init:
+            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
+
         with torch.no_grad():
             # Anthropic normalize this to have unit columns
             self.set_decoder_norm_to_unit_norm()
@@ -106,7 +109,7 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
         # move x to correct dtype
         x = x.to(self.dtype)
         sae_in = self.hook_sae_in(
-            x - self.b_dec
+            x - (self.b_dec * self.cfg.apply_b_dec_to_input)
         )  # Remove decoder bias as per Anthropic
 
         hidden_pre = self.hook_hidden_pre(
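The forward() change gates the decoder-bias subtraction arithmetically: a Python bool multiplies as 1 or 0, so one expression covers both settings without a branch. A minimal sketch (names mirror the diff; shapes are illustrative):

import torch

d_in = 512
b_dec = torch.randn(d_in)
x = torch.randn(4, d_in)

# apply_b_dec_to_input=True subtracts the bias; False leaves x unchanged.
for apply_b_dec_to_input in (True, False):
    sae_in = x - (b_dec * apply_b_dec_to_input)
    if not apply_b_dec_to_input:
        assert torch.equal(sae_in, x)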
(Diff for the fourth changed file, which carries most of the 5,225 additions, was not rendered in this view.)