progress on implementing multi-sae support
Benw8888 committed Feb 26, 2024
1 parent f3fe937 commit 2ba2131
Showing 7 changed files with 359 additions and 196 deletions.
41 changes: 41 additions & 0 deletions activation_storing.py
@@ -0,0 +1,41 @@
import os
import sys

import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.cache_activations_runner import cache_activations_runner
from sae_training.config import (
CacheActivationsRunnerConfig,
LanguageModelSAERunnerConfig,
)
from sae_training.lm_runner import language_model_sae_runner

cfg = CacheActivationsRunnerConfig(

# Data Generating Function (Model + Training Distribution)
model_name = "gpt2-small",
hook_point = f"blocks.{3}.hook_resid_pre",
hook_point_layer = 3,
d_in = 768,
dataset_path = "Skylion007/openwebtext",
is_dataset_tokenized=True,
cached_activations_path="activations/",

# Activation Store Parameters
n_batches_in_buffer = 16,
total_training_tokens = 300_000_000,
store_batch_size = 64,

# Activation caching shuffle parameters
n_shuffles_final = 16,

# Misc
device = "cuda",
seed = 42,
dtype = torch.bfloat16,
)

cache_activations_runner(cfg)
59 changes: 59 additions & 0 deletions lp_sae_training.py
@@ -0,0 +1,59 @@
import torch
import os
import sys

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(

# Data Generating Function (Model + Training Distribution)
model_name = "gpt2-small",
hook_point = "blocks.2.hook_resid_pre",
hook_point_layer = 2,
d_in = 768,
dataset_path = "Skylion007/openwebtext",
is_dataset_tokenized=False,

# SAE Parameters
expansion_factor = 64,
b_dec_init_method = "geometric_median",

# Training Parameters
lr = 4e-4,
l1_coefficient = 8e-5,
lr_scheduler_name="constantwithwarmup",
train_batch_size = 4096,
context_size = 128,
lr_warm_up_steps=5000,

# Activation Store Parameters
n_batches_in_buffer = 128,
total_training_tokens = 1_000_000 * 300,
store_batch_size = 32,

# Dead Neurons and Sparsity
use_ghost_grads=True,
feature_sampling_window = 1000,
dead_feature_window=5000,
dead_feature_threshold = 1e-6,

# WANDB
log_to_wandb = True,
wandb_project= "mats_sae_training_gpt2",
wandb_entity = None,
wandb_log_frequency=100,

# Misc
device = "cuda",
seed = 42,
n_checkpoints = 10,
checkpoint_path = "checkpoints",
dtype = torch.float32,
use_cached_activations = True,
)

sparse_autoencoder = language_model_sae_runner(cfg)
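
A minimal sketch (not from this commit's files) of how the new lp_norm field on LanguageModelSAERunnerConfig could sit next to the training parameters above; the value shown is just the field's default:

from sae_training.config import LanguageModelSAERunnerConfig

sketch_cfg = LanguageModelSAERunnerConfig(
    model_name = "gpt2-small",
    hook_point = "blocks.2.hook_resid_pre",
    hook_point_layer = 2,
    d_in = 768,
    l1_coefficient = 8e-5,
    lp_norm = 1,  # exponent p of the sparsity penalty in SparseAutoencoder.forward
)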
98 changes: 51 additions & 47 deletions sae_training/activations_store.py
@@ -24,15 +24,13 @@ def __init__(
self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
self.iterable_dataset = iter(self.dataset)

# check if it's tokenized
if "tokens" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = True
print("Dataset is tokenized! Updating config.")
elif "text" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = False
print("Dataset is not tokenized! Updating config.")
# Check if dataset is tokenized
dataset_sample = next(self.iterable_dataset)
self.cfg.is_dataset_tokenized = "tokens" in dataset_sample.keys()
print(f"Dataset is {'tokenized' if self.cfg.is_dataset_tokenized else 'not tokenized'}! Updating config.")
self.iterable_dataset = iter(self.dataset) # Reset iterator after checking

if self.cfg.use_cached_activations:
if self.cfg.use_cached_activations: #EDIT: load from multi-layer acts
# Sanity check: does the cache directory exist?
assert os.path.exists(
self.cfg.cached_activations_path
@@ -145,39 +143,52 @@ def get_batch_tokens(self):
return batch_tokens[:batch_size]

def get_activations(self, batch_tokens, get_loss=False):
act_name = self.cfg.hook_point
hook_point_layer = self.cfg.hook_point_layer
"""
Returns activations of shape (batches, context, num_layers, d_in)
"""
layers = self.cfg.hook_point_layer if isinstance(self.cfg.hook_point_layer, list) else [self.cfg.hook_point_layer]
act_names = [self.cfg.hook_point.format(layer = layer) for layer in layers]
hook_point_max_layer = max(layers)  # stop_at_layer must cover the deepest hook point
if self.cfg.hook_point_head_index is not None:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name][:, :, self.cfg.hook_point_head_index]
layerwise_activations = self.model.run_with_cache(
batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1
)[1]
activations_list = [
layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
for act_name in act_names
]
else:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name]

return activations

layerwise_activations = self.model.run_with_cache(
batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1
)[1]
activations_list = [
layerwise_activations[act_name]
for act_name in act_names
]

# Stack along a new dimension to keep separate layers distinct
stacked_activations = torch.stack(activations_list, dim=2)

return stacked_activations

def get_buffer(self, n_batches_in_buffer):
context_size = self.cfg.context_size
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = len(self.cfg.hook_point_layer) if isinstance(self.cfg.hook_point_layer, list) else 1 # Number of hook points or layers

if self.cfg.use_cached_activations:
# Load the activations from disk
buffer_size = total_size * context_size
# Initialize an empty tensor (flattened along all dims except d_in)
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, d_in), dtype=self.cfg.dtype, device=self.cfg.device
(buffer_size, num_layers, d_in), dtype=self.cfg.dtype, device=self.cfg.device
)
n_tokens_filled = 0

# The activations may be split across multiple files,
# Or we might only want a subset of one file (depending on the sizes)
# Assume activations for different layers are stored separately and need to be combined
while n_tokens_filled < buffer_size:
# Load the next file
# Make sure it exists
if not os.path.exists(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
):
@@ -193,55 +204,47 @@ def get_buffer(self, n_batches_in_buffer):
)
print(f"Returning a buffer of size {n_tokens_filled} instead.")
print("\n\n")
new_buffer = new_buffer[:n_tokens_filled]
break
new_buffer = new_buffer[:n_tokens_filled, ...]
return new_buffer

activations = torch.load(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
)

# If we only want a subset of the file, take it
taking_subset_of_file = False
if n_tokens_filled + activations.shape[0] > buffer_size:
activations = activations[: buffer_size - n_tokens_filled]
activations = activations[: buffer_size - n_tokens_filled, ...]
taking_subset_of_file = True

# Add it to the buffer
new_buffer[
n_tokens_filled : n_tokens_filled + activations.shape[0]
] = activations
new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0], ...] = activations

# Update counters
n_tokens_filled += activations.shape[0]
if taking_subset_of_file:
self.next_idx_within_buffer = activations.shape[0]
else:
self.next_cache_idx += 1
self.next_idx_within_buffer = 0

n_tokens_filled += activations.shape[0]

return new_buffer

refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
# refill_iterator = tqdm(refill_iterator, desc="generate activations")

# Initialize empty tensor buffer of the maximum required size
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, d_in),
(total_size, context_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
)

# Insert activations directly into pre-allocated buffer
# pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")

for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
] = refill_activations

# pbar.update(1)

new_buffer = new_buffer.reshape(-1, d_in)
new_buffer = new_buffer.reshape(-1, num_layers, d_in)
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

return new_buffer
Expand All @@ -261,10 +264,11 @@ def get_data_loader(

# 1. # create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer]
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer],
dim=0,
)

mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[1])]

# 2. put 50 % in storage
self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
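
A minimal sketch of the shapes this updated store produces, with illustrative sizes; the final slice shows how one SAE in a multi-SAE setup could read its own layer:

import torch

batch_size, context_size, num_layers, d_in = 32, 128, 2, 768

# Shape returned by get_activations(batch_tokens): one entry per hook point, stacked on dim 2
stacked_activations = torch.randn(batch_size, context_size, num_layers, d_in)

# Shape returned by get_buffer(...) after the reshape and shuffle
buffer = stacked_activations.reshape(-1, num_layers, d_in)
buffer = buffer[torch.randperm(buffer.shape[0])]

# A single SAE then trains on its own layer slice of shape (n_tokens, d_in)
layer_idx = 0
acts_for_one_sae = buffer[:, layer_idx, :]
print(acts_for_one_sae.shape)  # torch.Size([4096, 768])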
3 changes: 2 additions & 1 deletion sae_training/config.py
@@ -15,7 +15,7 @@ class RunnerConfig(ABC):

# Data Generating Function (Model + Training Distribution)
model_name: str = "gelu-2l"
hook_point: str = "blocks.0.hook_mlp_out"
hook_point: str = "blocks.{layer}.hook_mlp_out"
hook_point_layer: int = 0
hook_point_head_index: Optional[int] = None
dataset_path: str = "NeelNanda/c4-tokenized-2b"
@@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig(RunnerConfig):

# Training Parameters
l1_coefficient: float = 1e-3
lp_norm: float = 1
lr: float = 3e-4
lr_scheduler_name: str = "constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
lr_warm_up_steps: int = 500
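
The default hook_point is now a template string; together with hook_point_layer (which may be a list) it resolves to one hook name per layer, as ActivationsStore.get_activations does above. A small illustration with arbitrary layer indices:

hook_point_template = "blocks.{layer}.hook_mlp_out"
hook_point_layers = [0, 1, 2]

act_names = [hook_point_template.format(layer=layer) for layer in hook_point_layers]
print(act_names)  # ['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out']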
38 changes: 38 additions & 0 deletions sae_training/sae_group.py
@@ -0,0 +1,38 @@
from functools import partial
import torch.nn.functional as F
import dataclasses
from sae_training.sparse_autoencoder import SparseAutoencoder


class SAEGroup:
def __init__(self, cfg):
self.cfg = cfg
self.autoencoders = [] # This will store tuples of (instance, hyperparameters)
self._init_autoencoders(cfg)

def _init_autoencoders(self, cfg):
# Dynamically get all combinations of hyperparameters from cfg
from itertools import product

# Extract all hyperparameter lists from cfg
hyperparameters = {k: v for k, v in vars(cfg).items() if isinstance(v, list)}
keys, values = zip(*hyperparameters.items())

# Create all combinations of hyperparameters
for combination in product(*values):
params = dict(zip(keys, combination))
cfg_copy = dataclasses.replace(cfg, **params)
# Insert the layer into the hookpoint
cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg_copy.hook_point_layer)
# Create and store both the SparseAutoencoder instance and its parameters
self.autoencoders.append((SparseAutoencoder(cfg_copy), cfg_copy))

def __iter__(self):
# Make SAEGroup iterable over its SparseAutoencoder instances and their parameters
for ae, params in self.autoencoders:
yield ae, params # Yielding as a tuple

def __len__(self):
# Return the number of SparseAutoencoder instances
return len(self.autoencoders)
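
A hedged usage sketch of the intended SAEGroup interface, with illustrative field values, assuming the sweep machinery in _init_autoencoders works as written: every list-valued config field is treated as a hyperparameter sweep, and one SparseAutoencoder is built per combination.

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.sae_group import SAEGroup

cfg = LanguageModelSAERunnerConfig(
    hook_point = "blocks.{layer}.hook_resid_pre",
    hook_point_layer = 2,
    d_in = 768,
    l1_coefficient = [8e-5, 1e-4],  # list-valued -> swept over
    lp_norm = [1, 2],               # list-valued -> swept over
)

sae_group = SAEGroup(cfg)
print(len(sae_group))  # 4 autoencoders, one per (l1_coefficient, lp_norm) combination

for sae, sae_cfg in sae_group:
    print(sae_cfg.l1_coefficient, sae_cfg.lp_norm)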
3 changes: 2 additions & 1 deletion sae_training/sparse_autoencoder.py
@@ -35,6 +35,7 @@ def __init__(
)
self.d_sae = cfg.d_sae
self.l1_coefficient = cfg.l1_coefficient
self.lp_norm = cfg.lp_norm
self.dtype = cfg.dtype
self.device = cfg.device

@@ -132,7 +133,7 @@ def forward(self, x, dead_neuron_mask=None):
mse_loss_ghost_resid = mse_loss_ghost_resid.mean()

mse_loss = mse_loss.mean()
sparsity = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
sparsity = feature_acts.norm(p=self.lp_norm, dim=1).mean(dim=(0,))
l1_loss = self.l1_coefficient * sparsity
loss = mse_loss + l1_loss + mse_loss_ghost_resid

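
The sparsity term is now a configurable Lp norm over feature activations; with the default lp_norm = 1 it is numerically the same as the previous abs-sum penalty. A quick check with illustrative shapes:

import torch

feature_acts = torch.randn(8, 1024).relu()  # (batch, d_sae), post-activation, illustrative sizes

old_sparsity = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
new_sparsity = feature_acts.norm(p=1, dim=1).mean(dim=(0,))
assert torch.allclose(old_sparsity, new_sparsity)

# Other exponents (e.g. p = 0.5) give a different penalty through the same code path
half_norm_sparsity = feature_acts.norm(p=0.5, dim=1).mean(dim=(0,))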