Commit 4851dd1
jbloom-md committed Nov 30, 2023
1 parent: a61b75f
Showing 10 changed files with 306 additions and 24 deletions.
@@ -0,0 +1,19 @@
# Looks like setup.cfg cannot load the extensions in the pre-commit hook,
# but the pylintrc file can.
# This happens even when specifying --rcfile=setup.cfg
# Possible bug in pylint?
[MASTER]
load-plugins = pylint.extensions.docparams, pylint.extensions.docstyle, pylint.extensions.mccabe

[BASIC]
accept-no-param-doc = no
accept-no-raise-doc = no
accept-no-return-doc = no
accept-no-yields-doc = no
default-docstring-type = numpy

[FORMAT]
max-line-length = 88

[MESSAGES CONTROL]
disable = C0330, C0326, C0199, C0411
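
With accept-no-param-doc, accept-no-raise-doc, accept-no-return-doc and accept-no-yields-doc all set to no, and default-docstring-type = numpy, pylint's docparams extension flags public functions whose docstrings omit the corresponding sections. A minimal sketch of a docstring that would satisfy those checks (the function itself is hypothetical, not part of this commit):

def scale_activations(activations: list, factor: float = 1.0) -> list:
    """Scale a list of activation values by a constant factor.

    Parameters
    ----------
    activations : list
        Activation values to rescale.
    factor : float
        Multiplier applied to every activation.

    Returns
    -------
    list
        The rescaled activations, in the original order.
    """
    return [a * factor for a in activations]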
This file was deleted.
@@ -0,0 +1,46 @@
from datasets import load_dataset


# TODO: add preprocess_text_dataset and preprocessors for other dataset formats
def preprocess_tokenized_dataset(source_batch: dict, context_size: int) -> dict:
    tokenized_prompts = source_batch["tokens"]

    # Chunk each tokenized prompt into blocks of context_size,
    # discarding the last block if it is too small.
    context_size_prompts = []
    for encoding in tokenized_prompts:
        chunks = [
            encoding[i : i + context_size]
            for i in range(0, len(encoding), context_size)
            if len(encoding[i : i + context_size]) == context_size
        ]
        context_size_prompts.extend(chunks)

    return {"input_ids": context_size_prompts}


def get_mapped_dataset(cfg):
    # Load the dataset
    context_size = cfg["context_size"]
    dataset_path = cfg["dataset_path"]
    dataset_split = "train"
    buffer_size: int = 1000
    preprocess_batch_size: int = 1000

    dataset = load_dataset(dataset_path, streaming=True, split=dataset_split)  # type: ignore

    # Set up preprocessing
    existing_columns = list(next(iter(dataset)).keys())
    mapped_dataset = dataset.map(
        preprocess_tokenized_dataset,  # the preprocess function is what differentiates datasets
        batched=True,
        batch_size=preprocess_batch_size,
        fn_kwargs={"context_size": context_size},
        remove_columns=existing_columns,
    )

    # Set up approximate shuffling. As the dataset is streamed, this just pre-downloads at
    # least `buffer_size` items and then shuffles just that buffer.
    # https://huggingface.co/docs/datasets/v2.14.5/stream#shuffle
    dataset = mapped_dataset.shuffle(buffer_size=buffer_size)
    return dataset
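
A minimal sketch of how this loader might be called. The config is a plain dict here because get_mapped_dataset indexes it by key; the dataset path is the default from the runner config below, and everything else is illustrative:

from sae_training.lm_datasets import get_mapped_dataset

# Illustrative config; get_mapped_dataset reads cfg["context_size"] and cfg["dataset_path"].
cfg = {"context_size": 128, "dataset_path": "NeelNanda/c4-tokenized-2b"}
streamed = get_mapped_dataset(cfg)

# The streamed dataset yields dicts with a single "input_ids" field,
# each a chunk of exactly context_size tokens.
first_example = next(iter(streamed))
assert len(first_example["input_ids"]) == cfg["context_size"]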
@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Optional

import torch
from transformer_lens import HookedTransformer

from sae_training.activation_store import ActivationStore
from sae_training.lm_datasets import get_mapped_dataset
from sae_training.SAE import SAE
from sae_training.train_sae import train_sae


@dataclass
class SAERunnerConfig:

    # Data Generating Function (Model + Training Distribution)
    model_name: str = "gelu-2l"
    hook_point: str = "blocks.0.hook_mlp_out"
    hook_point_layer: int = 0
    dataset_path: str = "NeelNanda/c4-tokenized-2b"

    # SAE Parameters
    expansion_factor: int = 4

    # Training Parameters
    l1_coefficient: float = 1e-3
    lr: float = 3e-4
    train_batch_size: int = 4096
    context_size: int = 128

    # Activation Store Parameters
    # max_store_size: int = 384 * 4096 * 2
    # max_activations: int = 2_000_000_000
    # resample_frequency: int = 122_880_000
    # checkpoint_frequency: int = 100_000_000
    # validation_frequency: int = 384 * 4096 * 2 * 100

    # WANDB
    log_to_wandb: bool = True
    wandb_project: str = "mats_sae_training"
    wandb_entity: Optional[str] = None

    # Misc
    device: str = "cpu"
    seed: int = 42
    checkpoint_path: str = "checkpoints"
    dtype: torch.dtype = torch.float32


def sae_runner(cfg):
    model = HookedTransformer.from_pretrained(cfg.model_name)  # any other cfg we should pass in here?

    # initialize dataset
    dataset = get_mapped_dataset(cfg)
    activation_store = ActivationStore(cfg, dataset)

    # initialize the SAE
    sparse_autoencoder = SAE(cfg)

    # train SAE
    sparse_autoencoder = train_sae(
        model,
        activation_store,
        sparse_autoencoder,
        cfg,
    )

    return sparse_autoencoder
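
A minimal sketch of how this runner might be invoked. The diff does not show which module the file lives in, so the import location of SAERunnerConfig and sae_runner is assumed, and the values below simply restate the defaults:

# Hypothetical invocation; SAERunnerConfig / sae_runner are assumed importable
# from wherever this file ends up in the package.
cfg = SAERunnerConfig(
    model_name="gelu-2l",
    hook_point="blocks.0.hook_mlp_out",
    dataset_path="NeelNanda/c4-tokenized-2b",
    expansion_factor=4,
    train_batch_size=4096,
    context_size=128,
    device="cpu",
)
trained_autoencoder = sae_runner(cfg)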
@@ -0,0 +1,97 @@
from dataclasses import dataclass
from typing import Optional

import einops
import torch
from transformer_lens import HookedTransformer

import wandb
from sae_training.SAE import SAE
from sae_training.toy_models import Config as ToyConfig
from sae_training.toy_models import Model as ToyModel
from sae_training.train_sae import train_sae


@dataclass
class SAEToyModelRunnerConfig:
    # ReLU Model Parameters
    n_features: int = 5
    n_hidden: int = 2
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0
    feature_probability: float = 0.025
    # ReLU Model Training Parameters
    model_training_steps: int = 10_000
    # SAE Parameters
    expansion_factor: int = 4
    # Training Parameters
    n_sae_training_tokens: int = 25_000
    l1_coefficient: float = 1e-3
    lr: float = 3e-4
    train_batch_size: int = 32  # Shouldn't be as big as the batch size for language models
    train_epochs: int = 10
    # WANDB
    log_to_wandb: bool = True
    wandb_project: str = "mats_sae_training_toy_model"
    wandb_entity: Optional[str] = None
    # Misc
    device: str = "cpu"
    seed: int = 42
    checkpoint_path: str = "checkpoints"
    dtype: torch.dtype = torch.float32  # TODO: make this a string (use a dictionary to map)

    def __post_init__(self):
        self.d_in = self.n_hidden  # the hidden size of the ReLU model is the input size of the SAE
        self.d_sae = self.n_hidden * self.expansion_factor


def toy_model_sae_runner(cfg):
    """A runner for training an SAE on a toy model."""
    # Toy Model Config
    toy_model_cfg = ToyConfig(
        n_instances=1,  # Not set up to train > 1 SAE, so shouldn't do > 1 model.
        n_features=cfg.n_features,
        n_hidden=cfg.n_hidden,
        n_correlated_pairs=cfg.n_correlated_pairs,
        n_anticorrelated_pairs=cfg.n_anticorrelated_pairs,
    )

    # Initialize Toy Model
    model = ToyModel(
        cfg=toy_model_cfg,
        device="cpu",
        feature_probability=cfg.feature_probability,
    )

    # Train the Toy Model
    model.optimize(steps=cfg.model_training_steps)

    # Generate Training Data
    batch = model.generate_batch(cfg.n_sae_training_tokens)
    hidden = einops.einsum(
        batch,
        model.W,
        "batch_size instances features, instances hidden features -> batch_size instances hidden",
    )

    sae = SAE(cfg)  # config has the hyperparameters for the SAE

    if cfg.log_to_wandb:
        wandb.init(project="sae-training-test", config=cfg)

    sae = train_sae(
        sae,
        hidden.detach().squeeze(),
        use_wandb=cfg.log_to_wandb,
        l1_coeff=cfg.l1_coefficient,
        batch_size=cfg.train_batch_size,
        n_epochs=cfg.train_epochs,
    )

    if cfg.log_to_wandb:
        wandb.finish()

    return sae
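
A minimal sketch of how the toy-model runner might be used, again with the import location assumed since the diff does not show file names; wandb logging is turned off so the run stays local:

# Hypothetical invocation; SAEToyModelRunnerConfig / toy_model_sae_runner are
# assumed importable from wherever this file ends up in the package.
cfg = SAEToyModelRunnerConfig(
    n_features=5,
    n_hidden=2,
    feature_probability=0.025,
    model_training_steps=10_000,
    n_sae_training_tokens=25_000,
    train_batch_size=32,
    log_to_wandb=False,  # skip wandb for a quick local check
)
trained_toy_sae = toy_model_sae_runner(cfg)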
Empty file.