make toy model runner
jbloom-md committed Nov 30, 2023
1 parent a61b75f commit 4851dd1
Showing 10 changed files with 306 additions and 24 deletions.
19 changes: 19 additions & 0 deletions .pylintrc
@@ -0,0 +1,19 @@
# Looks like setup.cfg cannot load the extensions in pre-commit,
# but the pylintrc file can.
# This happens even when specifying --rcfile=setup.cfg
# Possible bug from pylint?
[MASTER]
load-plugins = pylint.extensions.docparams, pylint.extensions.docstyle, pylint.extensions.mccabe

[BASIC]
accept-no-param-doc = no
accept-no-raise-doc = no
accept-no-return-doc = no
accept-no-yields-doc = no
default-docstring-type = numpy

[FORMAT]
max-line-length = 88

[MESSAGES CONTROL]
disable = C0330, C0326, C0199, C0411
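
For illustration only (not part of the commit), here is the kind of numpy-style docstring these docparams settings require, since undocumented parameters, raises, returns, and yields are no longer accepted:

def scale(x: float, factor: float) -> float:
    """Scale a value by a constant factor.

    Parameters
    ----------
    x : float
        The value to scale.
    factor : float
        The multiplier applied to ``x``.

    Returns
    -------
    float
        The scaled value.
    """
    return x * factor
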
15 changes: 14 additions & 1 deletion .vscode/settings.json
@@ -3,5 +3,18 @@
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,

"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": false,
"editor.codeActionsOnSave": {
"source.organizeImports": true
},
},
"isort.args": [
"--profile",
"black"
],
"editor.defaultFormatter": "mikoz.black-py",
}
2 changes: 0 additions & 2 deletions pytest.ini

This file was deleted.

12 changes: 7 additions & 5 deletions sae_training/SAE.py
@@ -4,11 +4,13 @@
https://github.com/ArthurConmy/sae/blob/main/sae/model.py
"""
from typing import Literal

import einops
import torch
from torch import nn
import einops
from transformer_lens.hook_points import HookedRootModule, HookPoint


#%%
# TODO make sure that W_dec stays unit norm during training
class SAE(HookedRootModule):
@@ -18,14 +20,14 @@ def __init__(
):
super().__init__()
self.cfg = cfg
self.d_in = cfg["d_in"]
self.d_in = cfg.d_in
if not isinstance(self.d_in, int):
raise ValueError(
f"d_in must be an int but was {self.d_in=}; {type(self.d_in)=}"
)
self.d_sae = cfg["d_sae"]
self.dtype = cfg["dtype"]
self.device = cfg["device"]
self.d_sae = cfg.d_sae
self.dtype = cfg.dtype
self.device = cfg.device

# NOTE: if using resampling neurons method, you must ensure that we initialise the weights in the order W_enc, b_enc, W_dec, b_dec
self.W_enc = nn.Parameter(
46 changes: 46 additions & 0 deletions sae_training/lm_datasets.py
@@ -0,0 +1,46 @@
from datasets import load_dataset

# TODO: preprocess_tokenized_dataset, preprocess_text_dataset, preprocess other datasets
def preprocess_tokenized_dataset(source_batch: dict, context_size: int) -> dict:
tokenized_prompts = source_batch["tokens"]

# Chunk each tokenized prompt into blocks of context_size,
# discarding the last block if too small.
context_size_prompts = []
for encoding in tokenized_prompts:
chunks = [
encoding[i : i + context_size]
for i in range(0, len(encoding), context_size)
if len(encoding[i : i + context_size]) == context_size
]
context_size_prompts.extend(chunks)

return {"input_ids": context_size_prompts}


def get_mapped_dataset(cfg):
# Load the dataset
context_size = cfg["context_size"]
dataset_path = cfg["dataset_path"]
dataset_split = "train"
buffer_size: int = 1000
preprocess_batch_size: int = 1000

dataset = load_dataset(dataset_path, streaming=True, split=dataset_split) # type: ignore

# Setup preprocessing
existing_columns = list(next(iter(dataset)).keys())
mapped_dataset = dataset.map(
preprocess_tokenized_dataset, # preprocess is what differentiates different datasets
batched=True,
batch_size=preprocess_batch_size,
fn_kwargs={"context_size": context_size},
remove_columns=existing_columns,
)

# Setup approximate shuffling. As the dataset is streamed, this just pre-downloads at least
# `buffer_size` items and then shuffles just that buffer.
# https://huggingface.co/docs/datasets/v2.14.5/stream#shuffle
dataset = mapped_dataset.shuffle(buffer_size=buffer_size)
return dataset

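For reference, a small sketch (not part of the commit) of how the new preprocess_tokenized_dataset chunks a pre-tokenized batch; the toy token lists below are illustrative only:

from sae_training.lm_datasets import preprocess_tokenized_dataset

# Two pre-tokenized prompts; with context_size=4, trailing partial chunks are dropped.
source_batch = {
    "tokens": [
        list(range(10)),  # -> [0..3], [4..7]; the leftover [8, 9] is discarded
        list(range(7)),   # -> [0..3]; the leftover [4, 5, 6] is discarded
    ]
}

out = preprocess_tokenized_dataset(source_batch, context_size=4)
print(out["input_ids"])  # [[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 2, 3]]
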
67 changes: 67 additions & 0 deletions sae_training/lm_runner.py
@@ -0,0 +1,67 @@
from dataclasses import dataclass

import torch
from transformer_lens import HookedTransformer

from sae_training.activation_store import ActivationStore
from sae_training.lm_datasets import get_mapped_dataset
from sae_training.SAE import SAE
from sae_training.train_sae import train_sae


@dataclass
class SAERunnerConfig:

# Data Generating Function (Model + Training Distribution)
model_name: str = "gelu-2l"
hook_point: str = "blocks.0.hook_mlp_out"
hook_point_layer: int = 0
dataset_path: str = "NeelNanda/c4-tokenized-2b"

# SAE Parameters
expansion_factor: int = 4

# Training Parameters
l1_coefficient: float = 1e-3
lr: float = 3e-4
train_batch_size: int = 4096
context_size: int = 128

# Activation Store Parameters
# max_store_size: int = 384 * 4096 * 2
# max_activations: int = 2_000_000_000
# resample_frequency: int = 122_880_000
# checkpoint_frequency: int = 100_000_000
# validation_frequency: int = 384 * 4096 * 2 * 100

# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training"
wandb_entity: str = None

# Misc
device: str = "cpu"
seed: int = 42
checkpoint_path: str = "checkpoints"
dtype: torch.dtype = torch.float32

def sae_runner(cfg):

    model = HookedTransformer.from_pretrained(cfg.model_name)  # any other cfg we should pass in here?

# initialize dataset
dataset = get_mapped_dataset(cfg)
activation_store = ActivationStore(cfg, dataset)

# initialize the SAE
sparse_autoencoder = SAE(cfg)

# train SAE
sparse_autoencoder = train_sae(
    sparse_autoencoder,
    activation_store,
    use_wandb=cfg.log_to_wandb,
    l1_coeff=cfg.l1_coefficient,
    batch_size=cfg.train_batch_size,
)

return sparse_autoencoder
97 changes: 97 additions & 0 deletions sae_training/toy_model_runner.py
@@ -0,0 +1,97 @@
from dataclasses import dataclass

import einops
import torch
from transformer_lens import HookedTransformer

import wandb
from sae_training.SAE import SAE
from sae_training.toy_models import Config as ToyConfig
from sae_training.toy_models import Model as ToyModel
from sae_training.train_sae import train_sae


@dataclass
class SAEToyModelRunnerConfig:
# ReLU Model Parameters
n_features: int = 5
n_hidden: int = 2
n_correlated_pairs: int = 0
n_anticorrelated_pairs: int = 0
feature_probability: float = 0.025
# ReLU Model Training Parameters
model_training_steps: int = 10_000
# SAE Parameters
expansion_factor: int = 4
# Training Parameters
n_sae_training_tokens: int = 25_000
l1_coefficient: float = 1e-3
lr: float = 3e-4
train_batch_size: int = 32 # Shouldn't be as big as the batch size for language models
train_epochs: int = 10
# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training_toy_model"
wandb_entity: str = None
# Misc
device: str = "cpu"
seed: int = 42
checkpoint_path: str = "checkpoints"
dtype: torch.dtype = (
torch.float32
) # TODO: Make this a string (have a dictionary to map)

def __post_init__(self):
self.d_in = self.n_hidden  # the ReLU model's hidden layer is the input to the SAE
self.d_sae = self.n_hidden * self.expansion_factor


def toy_model_sae_runner(cfg):
"""
A runner for training an SAE on a toy model.
"""
# Toy Model Config
toy_model_cfg = ToyConfig(
n_instances=1, # Not set up to train > 1 SAE so shouldn't do > 1 model.
n_features=cfg.n_features,
n_hidden=cfg.n_hidden,
n_correlated_pairs=cfg.n_correlated_pairs,
n_anticorrelated_pairs=cfg.n_anticorrelated_pairs,
)

# Initialize Toy Model
model = ToyModel(
cfg=toy_model_cfg,
device="cpu",
feature_probability=cfg.feature_probability,
)

# Train the Toy Model
model.optimize(steps=cfg.model_training_steps)

# Generate Training Data
batch = model.generate_batch(cfg.n_sae_training_tokens)
hidden = einops.einsum(
batch,
model.W,
"batch_size instances features, instances hidden features -> batch_size instances hidden",
)

sae = SAE(cfg) # config has the hyperparameters for the SAE

if cfg.log_to_wandb:
wandb.init(project="sae-training-test", config=cfg)

sae = train_sae(
sae,
hidden.detach().squeeze(),
use_wandb=cfg.log_to_wandb,
l1_coeff=cfg.l1_coefficient,
batch_size=cfg.train_batch_size,
n_epochs=cfg.train_epochs,
)

if cfg.log_to_wandb:
wandb.finish()

return sae
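
A short usage sketch (not part of the commit) of the new runner; field values mirror the dataclass defaults above, and log_to_wandb=False simply avoids requiring a wandb login:

from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner

cfg = SAEToyModelRunnerConfig(
    n_features=5,
    n_hidden=2,
    feature_probability=0.025,
    model_training_steps=10_000,
    n_sae_training_tokens=25_000,
    train_batch_size=32,
    train_epochs=10,
    log_to_wandb=False,
    device="cpu",
)

sae = toy_model_sae_runner(cfg)  # returns the trained SAE
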
47 changes: 31 additions & 16 deletions sae_training/train_sae.py
@@ -1,18 +1,22 @@
#%%
import einops
import torch
from torch.utils.data import DataLoader
import einops
from tqdm import tqdm

import wandb
from sae_training.SAE import SAE
from sae_training.activation_store import ActivationStore
from sae_training.SAE import SAE


#%%
def train_sae(sae: SAE,
activation_store: ActivationStore,
n_epochs: int = 10,
batch_size: int = 32,
l1_coeff: float = 0.001,
use_wandb: bool = False):
use_wandb: bool = False,
wandb_log_freq: int = 10,):
"""
Takes an SAE and a bunch of activations and does a bunch of training steps
"""
@@ -22,8 +26,10 @@ def train_sae(sae: SAE,
optimizer = torch.optim.Adam(sae.parameters())

sae.train()
for _ in range(n_epochs):
for batch in dataloader:
n_training_steps = 0
for epoch in range(n_epochs):
pbar = tqdm(dataloader)
for step, batch in enumerate(pbar):
optimizer.zero_grad()

sae_out, hidden_post = sae(batch)
@@ -33,19 +39,26 @@ def train_sae(sae: SAE,
loss = mse_loss + l1_coeff * l1_loss

with torch.no_grad():

batch_size = batch.shape[0]
feature_mean_activation = hidden_post.mean(dim=0)
n_dead_features = (feature_mean_activation == 0).sum().item()
if use_wandb:
wandb.log({
"mse_loss": mse_loss.item(),
"l1_loss": l1_loss.item(),
"loss": loss.item(),
"l0": ((hidden_post != 0) / batch_size).sum().item(),
"l2": torch.norm(hidden_post, dim=1).mean().item(),
"hist": wandb.Histogram(feature_mean_activation.tolist()),
"n_dead_features": n_dead_features,
})
n_dead_features = (feature_mean_activation == 0).sum()
l0 = ((hidden_post != 0) / batch_size).sum()
l2_norm = torch.norm(hidden_post, dim=1).mean()


if use_wandb and (step % wandb_log_freq == 0):
wandb.log({
"mse_loss": mse_loss.item(),
"l1_loss": l1_loss.item(),
"loss": loss.item(),
"l0": l0.item(),
"l2": l2_norm.item(),
"hist": wandb.Histogram(feature_mean_activation.tolist()),
"n_dead_features": n_dead_features,
}, step=n_training_steps)

pbar.set_description(f"Epoch {epoch} | Step {step} | MSE Loss {mse_loss.item():.3f} | L1 Loss {l1_loss.item():.3f} | L0 {l0.item():.3f} | n_dead_features {n_dead_features}")

loss.backward()

@@ -70,6 +83,8 @@ def train_sae(sae: SAE,
# Make sure the rows of W_dec still have unit norm
with torch.no_grad():
sae.W_dec.data /= (torch.norm(sae.W_dec.data, dim=1, keepdim=True) + 1e-8)

n_training_steps += 1


return sae
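
Similarly, a hedged sketch (not part of the commit) of driving train_sae directly with a tensor of activations, as toy_model_runner.py does; it assumes the SAE config only needs the d_in, d_sae, dtype, and device attributes visible in the SAE.__init__ hunk above:

from types import SimpleNamespace

import torch

from sae_training.SAE import SAE
from sae_training.train_sae import train_sae

# Minimal config stub; in practice the runner dataclasses provide these fields.
cfg = SimpleNamespace(d_in=2, d_sae=8, dtype=torch.float32, device="cpu")

# Stand-in for stored model activations, shape (n_tokens, d_in).
activations = torch.randn(25_000, 2)

sae = SAE(cfg)
sae = train_sae(
    sae,
    activations,  # the DataLoader inside train_sae batches a plain tensor just fine
    n_epochs=2,
    batch_size=32,
    l1_coeff=1e-3,
    use_wandb=False,
)
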
Empty file added tests/benchmark/__init__.py
Empty file.