
Commit

happy with hyperpars on benchmark
jbloom-md committed Nov 30, 2023
1 parent f52c7bb commit 836298a
Showing 5 changed files with 141 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -16,4 +16,4 @@ default-docstring-type = numpy
max-line-length = 88

[MESSAGES CONTROL]
disable = C0330, C0326, C0199, C0411
disable = C0330, C0326, C0199, C0411, C0103
66 changes: 46 additions & 20 deletions sae_training/SAE.py
Expand Up @@ -7,7 +7,9 @@

import einops
import torch
from torch import nn
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from transformer_lens.hook_points import HookedRootModule, HookPoint


@@ -93,22 +95,46 @@ def forward(self, x, return_mode: Literal["sae_out", "hidden_post", "both"]="bot
else:
raise ValueError(f"Unexpected {return_mode=}")

def reinit_neurons(self, indices):
new_W_enc = torch.nn.init.kaiming_uniform_(
torch.empty(
self.d_in, indices.shape[0], dtype=self.dtype, device=self.device
)
) * self.cfg["resample_factor"]
new_b_enc = torch.zeros(
indices.shape[0], dtype=self.dtype, device=self.device
)
new_W_dec = torch.nn.init.kaiming_uniform_(
torch.empty(
indices.shape[0], self.d_in, dtype=self.dtype, device=self.device
)
)
self.W_enc.data[:, indices] = new_W_enc
self.b_enc.data[indices] = new_b_enc
self.W_dec.data[indices, :] = new_W_dec
self.W_dec /= torch.norm(self.W_dec, dim=1, keepdim=True)

@torch.no_grad()
def resample_neurons(
self,
x: Float[Tensor, "batch_size n_hidden"],
frac_active_in_window: Float[Tensor, "window n_hidden_ae"],
neuron_resample_scale: float,
) -> None:
'''
Resamples neurons that have been dead for `dead_neuron_window` steps, according to `frac_active`.
'''
sae_out = self.forward(x, return_mode="sae_out")
per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1).squeeze()

# Find the dead neurons in this instance. If all neurons are alive, continue
is_dead = (frac_active_in_window.sum(0) < 1e-8)
dead_neurons = torch.nonzero(is_dead).squeeze(-1)
alive_neurons = torch.nonzero(~is_dead).squeeze(-1)
n_dead = dead_neurons.numel()

if n_dead == 0:
return # If there are no dead neurons, we don't need to resample neurons

# Compute L2 loss for each element in the batch
# TODO: Check whether we need to go through more batches as features get sparse to find high l2 loss examples.
if per_token_l2_loss.max() < 1e-6:
return # If we have zero reconstruction loss, we don't need to resample neurons

# Draw `n_hidden_ae` samples from [0, 1, ..., batch_size-1], with probabilities proportional to l2_loss
distn = Categorical(probs = per_token_l2_loss / per_token_l2_loss.sum())
replacement_indices = distn.sample((n_dead,)) # shape [n_dead]

# Index into the batch of hidden activations to get our replacement values
replacement_values = (x - self.b_dec)[replacement_indices] # shape [n_dead n_input_ae]

# Get the norm of alive neurons (or 1.0 if there are no alive neurons)
W_enc_norm_alive_mean = 1.0 if len(alive_neurons) == 0 else self.W_enc[:, alive_neurons].norm(dim=0).mean().item()

# Use this to renormalize the replacement values
replacement_values = (replacement_values / (replacement_values.norm(dim=1, keepdim=True) + 1e-8)) * W_enc_norm_alive_mean * neuron_resample_scale

# Lastly, set the new weights & biases
self.W_enc.data[:, dead_neurons] = replacement_values.T.squeeze(1)
self.b_enc.data[dead_neurons] = 0.0
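
For context, the new resample_neurons above re-initialises dead features from inputs the SAE currently reconstructs badly: each batch element is sampled with probability proportional to its L2 reconstruction error, and the chosen (decoder-bias-centred) inputs become the replacement encoder directions. A minimal standalone sketch of that sampling step, using toy tensors and illustrative shapes rather than the repository's API:

import torch
from torch.distributions.categorical import Categorical

torch.manual_seed(0)
batch_size, d_in, n_dead = 8, 4, 3
x = torch.randn(batch_size, d_in)        # input activations
sae_out = torch.randn(batch_size, d_in)  # imperfect reconstructions
b_dec = torch.zeros(d_in)                # decoder bias

# Sample batch indices with probability proportional to per-token L2 loss,
# so poorly reconstructed inputs are more likely to seed the new features.
per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1)
distn = Categorical(probs=per_token_l2_loss / per_token_l2_loss.sum())
replacement_indices = distn.sample((n_dead,))  # shape [n_dead]

# The replacement directions are the centred inputs themselves, renormalised
# before they would be written into the dead columns of W_enc.
replacement_values = (x - b_dec)[replacement_indices]  # shape [n_dead, d_in]
replacement_values = replacement_values / (replacement_values.norm(dim=1, keepdim=True) + 1e-8)
print(replacement_values.shape)  # torch.Size([3, 4])
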
15 changes: 11 additions & 4 deletions sae_training/toy_model_runner.py
@@ -22,17 +22,21 @@ class SAEToyModelRunnerConfig:
# Relu Model Training Parameters
model_training_steps: int = 10_000
# SAE Parameters
expansion_factor: int = 4
d_sae: int = 5
# Training Parameters
n_sae_training_tokens: int = 25_000
l1_coefficient: float = 1e-3
lr: float = 3e-4
train_batch_size: int = 32 # Shouldn't be as big as the batch size for language models
train_batch_size: int = 1024 # Shouldn't be as big as the batch size for language models
train_epochs: int = 10
feature_sampling_window: int = 100
feature_reinit_scale: float = 0.2
dead_feature_threshold: float = 1e-8
# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training_toy_model"
wandb_entity: str = None
wandb_log_frequency: int = 50
# Misc
device: str = "cpu"
seed: int = 42
@@ -43,8 +47,6 @@ class SAEToyModelRunnerConfig:

def __post_init__(self):
self.d_in = self.n_hidden # hidden for the ReLu model is the input for the SAE
self.d_sae = self.n_hidden * self.expansion_factor


def toy_model_sae_runner(cfg):
'''
@@ -83,12 +85,17 @@ def toy_model_sae_runner(cfg):
wandb.init(project="sae-training-test", config=cfg)

sae = train_sae(
model, # need model so we can do evals for neuron resampling
sae,
hidden.detach().squeeze(),
use_wandb=cfg.log_to_wandb,
l1_coeff=cfg.l1_coefficient,
batch_size=cfg.train_batch_size,
n_epochs=cfg.train_epochs,
feature_sampling_window=cfg.feature_sampling_window,
feature_reinit_scale=cfg.feature_reinit_scale,
dead_feature_threshold=cfg.dead_feature_threshold,
wandb_log_frequency=cfg.wandb_log_frequency,
)

if cfg.log_to_wandb:
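
For reference, the config change above replaces the derived expansion_factor with an explicit d_sae, while __post_init__ still sets d_in from n_hidden. A hypothetical construction (import path inferred from the file name in this diff) might look like:

from sae_training.toy_model_runner import SAEToyModelRunnerConfig

cfg = SAEToyModelRunnerConfig(
    n_features=5,
    n_hidden=2,
    d_sae=5,                       # now set directly, not n_hidden * expansion_factor
    train_batch_size=1024,
    feature_sampling_window=100,
    feature_reinit_scale=0.2,
    dead_feature_threshold=1e-8,
)
assert cfg.d_in == cfg.n_hidden    # still derived in __post_init__
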
93 changes: 65 additions & 28 deletions sae_training/train_sae.py
Expand Up @@ -7,59 +7,98 @@
import wandb
from sae_training.activation_store import ActivationStore
from sae_training.SAE import SAE
from sae_training.toy_models import Model as ToyModel


#%%
def train_sae(sae: SAE,
def train_sae(model: ToyModel,
sae: SAE,
activation_store: ActivationStore,
n_epochs: int = 10,
batch_size: int = 32,
batch_size: int = 1024,
l1_coeff: float = 0.001,
feature_sampling_window: int = 100, # how many training steps between resampling the features / considering neurons dead
feature_reinit_scale: float = 0.2, # how much to scale the resampled features by
dead_feature_threshold: float = 1e-8, # how infrequently a feature has to be active to be considered dead
use_wandb: bool = False,
wandb_log_freq: int = 10,):
wandb_log_frequency: int = 50,):
"""
Takes an SAE and a bunch of activations and does a bunch of training steps
"""

dataloader = DataLoader(activation_store, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.Adam(sae.parameters())
frac_active_list = [] # track active features

sae.train()
n_training_steps = 0
for epoch in range(n_epochs):
pbar = tqdm(dataloader)
for step, batch in enumerate(pbar):
optimizer.zero_grad()

sae_out, hidden_post = sae(batch)

# Make sure the rows of W_dec still have unit norm
sae.W_dec.data /= (torch.norm(sae.W_dec.data, dim=1, keepdim=True) + 1e-8)

# Resample dead neurons
if (feature_sampling_window is not None) and ((step + 1) % feature_sampling_window == 0):

# Get the fraction of neurons active in the previous window
frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)

# Compute batch of hidden activations which we'll use in resampling
resampling_batch = model.generate_batch(batch_size)

# Our version of running the model
hidden = einops.einsum(
resampling_batch,
model.W,
"batch_size instances features, instances hidden features -> batch_size instances hidden",
)

# Resample
sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)


# Update learning rate here if using scheduler.

# Forward and Backward Passes
optimizer.zero_grad()
sae_out, feature_acts = sae(batch)
# loss = reconstruction MSE + L1 regularization
mse_loss = ((sae_out - batch)**2).mean()
l1_loss = torch.abs(hidden_post).sum()
l1_loss = torch.abs(feature_acts).sum()
loss = mse_loss + l1_coeff * l1_loss

with torch.no_grad():

batch_size = batch.shape[0]
frac_feature_activation = (hidden_post > 0).float().mean(0)
log_frac_feature_activation = torch.log(frac_feature_activation + 1e-8)
n_dead_features = (frac_feature_activation > 0).sum()
l0 = ((hidden_post != 0) / batch_size).sum()
l2_norm = torch.norm(hidden_post, dim=1).mean()
# Calculate the sparsities, and add it to a list
frac_active = einops.reduce(
(feature_acts.abs() > dead_feature_threshold).float(),
"batch_size hidden_ae -> hidden_ae", "mean")
frac_active_list.append(frac_active)

batch_size = batch.shape[0]
log_frac_feature_activation = torch.log(frac_active + 1e-8)
n_dead_features = (frac_active < dead_feature_threshold).sum()

if use_wandb and (step % wandb_log_freq == 0):
wandb.log({
"losses/mse_loss": mse_loss.item(),
"losses/l1_loss": l1_loss.item(),
"losses/overall_loss": loss.item(),
"metrics/l0": l0.item(),
"metrics/l2": l2_norm.item(),
"metrics/feature_density_histogram": wandb.Histogram(log_frac_feature_activation.tolist()),
"metrics/n_dead_features": n_dead_features,
}, step=n_training_steps)
l0 = (feature_acts > 0).float().mean()
l2_norm = torch.norm(feature_acts, dim=1).mean()

pbar.set_description(f"{epoch}/{step}| MSE Loss {mse_loss.item():.3f} | L0 {l0.item():.3f} | n_dead_features {n_dead_features}")

if use_wandb and ((step + 1) % wandb_log_frequency == 0):
wandb.log({
"losses/mse_loss": mse_loss.item(),
"losses/l1_loss": batch_size*l1_loss.item(),
"losses/overall_loss": loss.item(),
"metrics/l0": l0.item(),
"metrics/l2": l2_norm.item(),
# "metrics/feature_density_histogram": wandb.Histogram(log_frac_feature_activation.tolist()),
"metrics/n_dead_features": n_dead_features,
"metrics/n_alive_features": sae.d_sae - n_dead_features,
}, step=n_training_steps)

pbar.set_description(f"{epoch}/{step}| MSE Loss {mse_loss.item():.3f} | L0 {l0.item():.3f} | n_dead_features {n_dead_features}")

loss.backward()

@@ -81,10 +120,8 @@ def train_sae(sae: SAE,

optimizer.step()

# Make sure the rows of W_dec still have unit norm
with torch.no_grad():
sae.W_dec.data /= (torch.norm(sae.W_dec.data, dim=1, keepdim=True) + 1e-8)



n_training_steps += 1


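
The sparsity bookkeeping in the new training loop is what feeds resample_neurons: every step records, per feature, the fraction of the batch on which it fired above dead_feature_threshold, and a feature whose summed activity over the last feature_sampling_window measurements is (near) zero is treated as dead. A self-contained sketch of that windowed check, with toy tensors and illustrative names rather than the repository's API:

import torch
import einops

torch.manual_seed(0)
window, batch_size, d_sae = 5, 16, 8
dead_feature_threshold = 1e-8

frac_active_list = []
for _ in range(window):
    feature_acts = torch.relu(torch.randn(batch_size, d_sae))
    feature_acts[:, 0] = 0.0  # force feature 0 to stay silent for the whole window
    frac_active = einops.reduce(
        (feature_acts.abs() > dead_feature_threshold).float(),
        "batch_size hidden_ae -> hidden_ae", "mean")
    frac_active_list.append(frac_active)

# Stack the last `window` measurements and flag features that never fired.
frac_active_in_window = torch.stack(frac_active_list[-window:], dim=0)  # [window, d_sae]
is_dead = frac_active_in_window.sum(0) < 1e-8
print(torch.nonzero(is_dead).squeeze(-1))  # tensor([0])
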
28 changes: 18 additions & 10 deletions tests/benchmark/test_sae_runner.py
@@ -8,18 +8,26 @@


def test_toy_model_sae_runner():

cfg = SAEToyModelRunnerConfig(
n_features = 5,
n_hidden = 2,
n_correlated_pairs = 0,
n_anticorrelated_pairs = 0,
feature_probability = 0.025,
model_training_steps = 10_000,
n_sae_training_tokens = 50_000,
log_to_wandb = True,
n_features=5,
n_hidden=2,
n_correlated_pairs=0,
n_anticorrelated_pairs=0,
feature_probability=0.025,
# SAE Parameters
d_sae=5,
l1_coefficient=0.005,
# SAE Train Config
train_batch_size=1024,
feature_sampling_window=3_000,
feature_reinit_scale=0.5,
model_training_steps=10_000,
n_sae_training_tokens=1024*10_000,
train_epochs=1,
log_to_wandb=False,
wandb_log_frequency=5,
)

trained_sae = toy_model_sae_runner(cfg)

assert trained_sae is not None
assert trained_sae is not None
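
Since the benchmark above appears to be a plain pytest test, it can presumably be run on its own under a standard pytest setup, for example programmatically (path taken from the diff header above; the -k filter is optional):

import pytest

pytest.main(["tests/benchmark/test_sae_runner.py", "-k", "test_toy_model_sae_runner"])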
