
Commit

commit_various_things_in_progress
jbloom-md committed Dec 9, 2023
1 parent 6f4030c commit 3843c39
Showing 15 changed files with 6,919 additions and 233 deletions.
3 changes: 3 additions & 0 deletions sae_training/activations_store.py
@@ -127,12 +127,15 @@ def get_buffer(self, n_batches_in_buffer):
)

# Insert activations directly into pre-allocated buffer
pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens).to(self.cfg.device)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations

pbar.update(1)

new_buffer = new_buffer.reshape(-1, d_in)
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]
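For context, the change above sits inside a pre-allocate / fill / shuffle pattern, only part of which is visible in this hunk. A minimal sketch of that pattern, with illustrative shapes and a hypothetical get_batch_activations helper standing in for the store's own methods:

import torch
from tqdm import tqdm

def fill_and_shuffle_buffer(get_batch_activations, n_batches_in_buffer, batch_size, context_size, d_in, device="cpu"):
    # Pre-allocate the whole buffer once, then write each refill batch into its slice.
    new_buffer = torch.zeros(n_batches_in_buffer * batch_size, context_size, d_in, device=device)
    pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
    for i in range(n_batches_in_buffer):
        start = i * batch_size
        new_buffer[start : start + batch_size] = get_batch_activations()  # [batch_size, context_size, d_in]
        pbar.update(1)
    pbar.close()
    # Flatten token positions together and shuffle so later training batches mix activations across contexts.
    new_buffer = new_buffer.reshape(-1, d_in)
    return new_buffer[torch.randperm(new_buffer.shape[0])]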
3 changes: 1 addition & 2 deletions sae_training/config.py
@@ -30,13 +30,12 @@ class LanguageModelSAERunnerConfig:
context_size: int = 128

# Resampling protocol args
feature_sampling_method: str = "l2" # None, l2, or anthropic
feature_sampling_method: str = "l2" # None or l2
feature_sampling_window: int = 200
feature_reinit_scale: float = 0.2
dead_feature_window: int = 100  # unless this window is larger than feature_sampling_window
dead_feature_threshold: float = 1e-8


# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
57 changes: 46 additions & 11 deletions sae_training/sparse_autoencoder.py
@@ -101,6 +101,7 @@ def resample_neurons(
x: Float[Tensor, "batch_size n_hidden"],
feature_sparsity: Float[Tensor, "n_hidden_ae"],
neuron_resample_scale: float,
optimizer: torch.optim.Optimizer,
) -> None:
'''
Resamples neurons that have been dead for `dead_neuron_window` steps, according to `frac_active`.
@@ -109,7 +110,7 @@ def resample_neurons(
per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1).squeeze()

# Find the dead neurons in this instance. If all neurons are alive, continue
is_dead = (feature_sparsity < 1e-8)
is_dead = (feature_sparsity < self.cfg.dead_feature_threshold)
dead_neurons = torch.nonzero(is_dead).squeeze(-1)
alive_neurons = torch.nonzero(~is_dead).squeeze(-1)
n_dead = dead_neurons.numel()
@@ -122,26 +123,60 @@ def resample_neurons(
if per_token_l2_loss.max() < 1e-6:
return 0 # If we have zero reconstruction loss, we don't need to resample neurons

# Draw `n_hidden_ae` samples from [0, 1, ..., batch_size-1], with probabilities proportional to l2_loss
# Draw `n_hidden_ae` samples from [0, 1, ..., batch_size-1], with probabilities proportional to l2_loss squared
distn = Categorical(probs = per_token_l2_loss.pow(2) / (per_token_l2_loss.pow(2).sum()))
n_samples = n_dead#min(n_dead, feature_sparsity.shape[-1] // self.cfg.expansion_factor) # don't reinit more than 10% of neurons at a time
replacement_indices = distn.sample((n_samples,)) # shape [n_dead]
n_resampled_neurons = n_dead
replacement_indices = distn.sample((n_resampled_neurons,)) # shape [n_dead]

# Index into the batch of hidden activations to get our replacement values
replacement_values = (x - self.b_dec)[replacement_indices] # shape [n_dead n_input_ae]

# unit norm
replacement_values = (replacement_values / (replacement_values.norm(dim=1, keepdim=True) + 1e-8))

# Set new decoder weights
self.W_dec.data[is_dead, :] = replacement_values

# Get the norm of alive neurons (or 1.0 if there are no alive neurons)
W_enc_norm_alive_mean = 1.0 if len(alive_neurons) == 0 else self.W_enc[:, alive_neurons].norm(dim=0).mean().item()

# Use this to renormalize the replacement values
replacement_values = (replacement_values / (replacement_values.norm(dim=1, keepdim=True) + 1e-8)) * W_enc_norm_alive_mean * neuron_resample_scale

# Lastly, set the new weights & biases
d_neurons_to_be_replaced = dead_neurons[:n_samples] # not restarting all!
self.W_enc.data[:, d_neurons_to_be_replaced] = replacement_values.T
self.b_enc.data[d_neurons_to_be_replaced] = 0.0
self.W_enc.data[:, is_dead] = (replacement_values * W_enc_norm_alive_mean * neuron_resample_scale).T
self.b_enc.data[is_dead] = 0.0


# reset the Adam Optimiser for every modified weight and bias term
# Reset all the Adam parameters
for dict_idx, (k, v) in enumerate(optimizer.state.items()):
for v_key in ["exp_avg", "exp_avg_sq"]:
if dict_idx == 0:
assert k.data.shape == (self.d_in, self.d_sae)
v[v_key][:, is_dead] = 0.0
elif dict_idx == 1:
assert k.data.shape == (self.d_sae,)
v[v_key][is_dead] = 0.0
elif dict_idx == 2:
assert k.data.shape == (self.d_sae, self.d_in)
v[v_key][is_dead, :] = 0.0
elif dict_idx == 3:
assert k.data.shape == (self.d_in,)
else:
raise ValueError(f"Unexpected dict_idx {dict_idx}")

# Check that the opt is really updated
for dict_idx, (k, v) in enumerate(optimizer.state.items()):
for v_key in ["exp_avg", "exp_avg_sq"]:
if dict_idx == 0:
if k.data.shape != (self.d_in, self.d_sae):
print(
"Warning: it does not seem as if resetting the Adam parameters worked, there are shapes mismatches"
)
if v[v_key][:, replacement_indices].abs().max().item() > 1e-6:
print(
"Warning: it does not seem as if resetting the Adam parameters worked"
)

return n_samples
return n_resampled_neurons

@torch.no_grad()
def set_decoder_norm_to_unit_norm(self):
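The new optimizer argument and the loop above zero Adam's running moment estimates (exp_avg, exp_avg_sq) for every resampled feature, so freshly initialised weights do not inherit stale momentum; the dict_idx/shape asserts expect the optimizer to hold W_enc, b_enc, W_dec, b_dec in that order. A minimal standalone sketch of the same reset, using a hypothetical two-parameter setup rather than the repo's classes:

import torch

d_in, d_sae = 16, 64
W_enc = torch.nn.Parameter(torch.randn(d_in, d_sae))
b_enc = torch.nn.Parameter(torch.zeros(d_sae))
optimizer = torch.optim.Adam([W_enc, b_enc], lr=1e-3)

# One dummy step so Adam has populated its moment buffers.
(W_enc.sum() + b_enc.sum()).backward()
optimizer.step()

# Suppose these features were just resampled.
is_dead = torch.zeros(d_sae, dtype=torch.bool)
is_dead[torch.tensor([3, 7, 11])] = True

for param in (W_enc, b_enc):
    state = optimizer.state[param]  # contains "exp_avg" and "exp_avg_sq" after step()
    for key in ("exp_avg", "exp_avg_sq"):
        if param.ndim == 2:
            state[key][:, is_dead] = 0.0  # zero the columns of resampled features in W_enc
        else:
            state[key][is_dead] = 0.0     # zero the resampled entries of b_enc

Without such a reset, Adam's accumulated moments for the old, dead weights would immediately pull the re-initialised features back toward their previous values.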
28 changes: 17 additions & 11 deletions sae_training/toy_model_runner.py
@@ -13,37 +13,44 @@

@dataclass
class SAEToyModelRunnerConfig:

# ReLu Model Parameters
n_features: int = 5
n_hidden: int = 2
n_correlated_pairs: int = 0
n_anticorrelated_pairs: int = 0
feature_probability: float = 0.025
# Relu Model Training Parameters
model_training_steps: int = 10_000

# SAE Parameters
d_sae: int = 5

# Training Parameters
n_sae_training_tokens: int = 25_000
l1_coefficient: float = 1e-3
lr: float = 3e-4
train_batch_size: int = 1024 # Shouldn't be as big as the batch size for language models
train_epochs: int = 10
train_batch_size: int = 1024

# Resampling protocol args
feature_sampling_method: str = "l2" # None or l2
feature_sampling_window: int = 100
feature_reinit_scale: float = 0.2
dead_feature_window: int = 100  # unless this window is larger than feature_sampling_window
dead_feature_threshold: float = 1e-8

# Activation Store Parameters
total_training_tokens: int = 25_000

# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training_toy_model"
wandb_entity: str = None
wandb_log_frequency: int = 50

# Misc
device: str = "cpu"
seed: int = 42
checkpoint_path: str = "checkpoints"
dtype: torch.dtype = (
torch.float32
) # TODO: Make this a string (have a dictionary to map)
dtype: torch.dtype = torch.float32

def __post_init__(self):
self.d_in = self.n_hidden # hidden for the ReLu model is the input for the SAE
@@ -72,7 +79,7 @@ def toy_model_sae_runner(cfg):
model.optimize(steps=cfg.model_training_steps)

# Generate Training Data
batch = model.generate_batch(cfg.n_sae_training_tokens)
batch = model.generate_batch(cfg.total_training_tokens)
hidden = einops.einsum(
batch,
model.W,
Expand All @@ -82,18 +89,17 @@ def toy_model_sae_runner(cfg):
sparse_autoencoder = SparseAutoencoder(cfg) # config has the hyperparameters for the SAE

if cfg.log_to_wandb:
wandb.init(project="sae-training-test", config=cfg)
wandb.init(project=cfg.wandb_project, config=cfg)

sparse_autoencoder = train_toy_sae(
model, # need model so we can do evals for neuron resampling
sparse_autoencoder,
hidden.detach().squeeze(),
use_wandb=cfg.log_to_wandb,
batch_size=cfg.train_batch_size,
n_epochs=cfg.train_epochs,
feature_sampling_window=cfg.feature_sampling_window,
feature_reinit_scale=cfg.feature_reinit_scale,
dead_feature_threshold=cfg.dead_feature_threshold,
use_wandb=cfg.log_to_wandb,
wandb_log_frequency=cfg.wandb_log_frequency,
)

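For orientation, a run with the updated config above could be launched roughly as follows; the field values are illustrative assumptions, not taken from this commit:

import torch
from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner

cfg = SAEToyModelRunnerConfig(
    n_features=5,
    n_hidden=2,
    d_sae=5,
    total_training_tokens=25_000,   # replaces the old n_sae_training_tokens field
    feature_sampling_method="l2",
    dead_feature_window=100,
    dead_feature_threshold=1e-8,
    log_to_wandb=False,             # keep the example offline
    device="cpu",
    dtype=torch.float32,
)
trained_sae = toy_model_sae_runner(cfg)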
29 changes: 12 additions & 17 deletions sae_training/toy_models.py
@@ -5,27 +5,22 @@
'''
from dataclasses import dataclass
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from torch import nn, Tensor
import torch as t
from tqdm import tqdm
from typing import Callable, List, Optional, Tuple, Union

import einops
from torch.nn import functional as F
import torch as t
from torch import Tensor
from IPython.display import clear_output
from typing import List, Union, Optional
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from typing import Tuple, List
from jaxtyping import Float
import einops
import torch as t
from IPython.display import clear_output
from jaxtyping import Float, Int
from matplotlib import pyplot as plt
from matplotlib.widgets import Slider # , Button
from matplotlib.animation import FuncAnimation
from matplotlib.widgets import Slider # , Button
from plotly.subplots import make_subplots
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm import tqdm

device = "cpu"

Expand Down Expand Up @@ -187,7 +182,7 @@ def optimize(
'''
optimizer = t.optim.Adam(list(self.parameters()), lr=lr)

progress_bar = tqdm(range(steps))
progress_bar = tqdm(range(steps), desc="Training Toy Model")

for step in progress_bar:

45 changes: 15 additions & 30 deletions sae_training/train_sae_on_language_model.py
@@ -28,9 +28,10 @@ def train_sae_on_language_model(
optimizer = torch.optim.Adam(sparse_autoencoder.parameters())
sparse_autoencoder.train()

frac_active_list = [] # track active features


# track active features
act_freq_scores = torch.zeros(sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device)
n_frac_active_tokens = 0

total_training_tokens = sparse_autoencoder.cfg.total_training_tokens
n_training_steps = 0
n_training_tokens = 0
@@ -50,15 +51,16 @@ def train_sae_on_language_model(
if (feature_sampling_method is not None) and ((n_training_steps + 1) % dead_feature_window == 0):

# Get the fraction of neurons active in the previous window
frac_active_in_window = torch.stack(frac_active_list[-dead_feature_window:], dim=0)
feature_sparsity = frac_active_in_window.sum(0) / (
dead_feature_window * batch_size
)
feature_sparsity = act_freq_scores / n_frac_active_tokens
# is_dead = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold)

# if standard resampling <- do this
n_resampled_neurons = sparse_autoencoder.resample_neurons(
activation_store.next_batch(),
feature_sparsity,
feature_reinit_scale)
feature_reinit_scale,
optimizer
)

else:
n_resampled_neurons = 0
@@ -72,32 +74,14 @@ def train_sae_on_language_model(

with torch.no_grad():
# Calculate the sparsities, and add it to a list, calculate sparsity metrics
act_freq_scores = (feature_acts.abs() > 0).float().sum(0)
frac_active_list.append(act_freq_scores)

if len(frac_active_list) > feature_sampling_window:
frac_active_in_window = torch.stack(
frac_active_list[-feature_sampling_window:], dim=0
)
feature_sparsity = frac_active_in_window.sum(0) / (
feature_sampling_window * batch_size
)
else:
# use the whole list
frac_active_in_window = torch.stack(
frac_active_list, dim=0)
feature_sparsity = frac_active_in_window.sum(0) / (
len(frac_active_list) * batch_size
)
act_freq_scores += (feature_acts.abs() > 0).float().sum(0)
n_frac_active_tokens += batch_size
feature_sparsity = act_freq_scores / n_frac_active_tokens

# metrics for currents acts
l0 = (feature_acts > 0).float().sum(1).mean()
l2_norm = torch.norm(feature_acts, dim=1).mean()

# don't want to risk not seeing these.
if use_wandb:
wandb.log({"metrics/n_resampled_neurons": n_resampled_neurons}, n_training_steps)

if use_wandb and ((n_training_steps + 1) % wandb_log_frequency == 0):
wandb.log(
{
@@ -121,12 +105,13 @@ def train_sae_on_language_model(
.mean()
.item(),
"details/n_training_tokens": n_training_tokens,
"metrics/n_resampled_neurons": n_resampled_neurons,
},
step=n_training_steps,
)

if (n_training_steps + 1) % (wandb_log_frequency * 100) == 0:
log_feature_sparsity = torch.log(feature_sparsity + 1e-8)
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
wandb.log(
{
"plots/feature_density_histogram": wandb.Histogram(
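The rewrite above swaps the windowed frac_active_list bookkeeping for two running counters: act_freq_scores accumulates how often each feature has fired and n_frac_active_tokens counts tokens seen, so feature sparsity is simply their ratio, and the histogram now logs log10 densities. A minimal sketch of that running estimator with made-up shapes and random activations:

import torch

d_sae, batch_size = 128, 32
act_freq_scores = torch.zeros(d_sae)
n_frac_active_tokens = 0

for _ in range(100):  # stand-in for the training loop
    feature_acts = torch.relu(torch.randn(batch_size, d_sae))    # fake SAE feature activations
    act_freq_scores += (feature_acts.abs() > 0).float().sum(0)   # count firings per feature
    n_frac_active_tokens += batch_size

feature_sparsity = act_freq_scores / n_frac_active_tokens        # fraction of tokens each feature fires on
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)     # what the feature-density histogram logs

Unlike the old approach, this never stores a per-batch list, so memory stays constant over long runs; the trade-off is that the estimate is cumulative rather than restricted to a recent window.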
(Diffs for the remaining 9 changed files not loaded.)
