progress on implementing multi-sae support
Benw8888 committed on Feb 26, 2024
1 parent f3fe937 · commit 2ba2131
Showing 7 changed files with 359 additions and 196 deletions.
@@ -0,0 +1,41 @@
import os
import sys

import torch

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.cache_activations_runner import cache_activations_runner
from sae_training.config import (
    CacheActivationsRunnerConfig,
    LanguageModelSAERunnerConfig,
)
from sae_training.lm_runner import language_model_sae_runner

cfg = CacheActivationsRunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point=f"blocks.{3}.hook_resid_pre",
    hook_point_layer=3,
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=True,
    cached_activations_path="activations/",

    # Activation Store Parameters
    n_batches_in_buffer=16,
    total_training_tokens=300_000_000,
    store_batch_size=64,

    # Activation caching shuffle parameters
    n_shuffles_final=16,

    # Misc
    device="cuda",
    seed=42,
    dtype=torch.bfloat16,
)

cache_activations_runner(cfg)
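
A quick way to sanity-check the cache this script produces; a minimal sketch, assuming the runner writes activation buffers as .pt tensors under cached_activations_path. The "0.pt" file name is an assumption for illustration, not something this diff confirms.

import torch

# Hypothetical: load the first cached buffer and confirm it holds
# (n_activations, d_in) rows in the dtype requested above (bfloat16).
buffer = torch.load("activations/0.pt")  # file name is an assumption
print(buffer.shape, buffer.dtype)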
@@ -0,0 +1,59 @@
import torch
import os
import sys

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.2.hook_resid_pre",
    hook_point_layer=2,
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,

    # SAE Parameters
    expansion_factor=64,
    b_dec_init_method="geometric_median",

    # Training Parameters
    lr=4e-4,
    l1_coefficient=8e-5,
    lr_scheduler_name="constantwithwarmup",
    train_batch_size=4096,
    context_size=128,
    lr_warm_up_steps=5000,

    # Activation Store Parameters
    n_batches_in_buffer=128,
    total_training_tokens=1_000_000 * 300,
    store_batch_size=32,

    # Dead Neurons and Sparsity
    use_ghost_grads=True,
    feature_sampling_window=1000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-6,

    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2",
    wandb_entity=None,
    wandb_log_frequency=100,

    # Misc
    device="cuda",
    seed=42,
    n_checkpoints=10,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
    use_cached_activations=True,
)

sparse_autoencoder = language_model_sae_runner(cfg)
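
Given the multi-SAE support this commit is building (the SAEGroup class below expands any list-valued config field into one SAE per hyperparameter combination), the intended usage appears to be passing lists where a sweep is wanted. A hedged sketch: whether LanguageModelSAERunnerConfig already accepts lists at this point in the commit history is not confirmed by this diff.

# Hypothetical multi-SAE sweep: list-valued fields would be expanded by
# SAEGroup into one SparseAutoencoder per combination (2 x 2 = 4 here).
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2-small",
    hook_point="blocks.{layer}.hook_resid_pre",  # "{layer}" filled in per-SAE
    hook_point_layer=2,
    d_in=768,
    expansion_factor=[32, 64],    # swept
    l1_coefficient=[8e-5, 1e-4],  # swept
)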
@@ -0,0 +1,38 @@
from functools import partial
import torch.nn.functional as F
import dataclasses
from sae_training.sparse_autoencoder import SparseAutoencoder


class SAEGroup:
    def __init__(self, cfg):
        self.cfg = cfg
        self.autoencoders = []  # This will store tuples of (instance, hyperparameters)
        self._init_autoencoders(cfg)

    def _init_autoencoders(self, cfg):
        # Dynamically get all combinations of hyperparameters from cfg
        from itertools import product

        # Extract all hyperparameter lists from cfg
        hyperparameters = {k: v for k, v in vars(cfg).items() if isinstance(v, list)}
        keys, values = zip(*hyperparameters.items())

        # Create all combinations of hyperparameters
        for combination in product(*values):
            params = dict(zip(keys, combination))
            # dataclasses.replace returns a copy of cfg with the swept fields overridden
            cfg_copy = dataclasses.replace(cfg, **params)
            # Insert the layer into the hook point
            cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg_copy.hook_point_layer)
            # Create and store both the SparseAutoencoder instance and its parameters
            self.autoencoders.append((SparseAutoencoder(cfg_copy), cfg_copy))

    def __iter__(self):
        # Make SAEGroup iterable over its SparseAutoencoder instances and their parameters
        for ae, params in self.autoencoders:
            yield ae, params  # Yielding as a tuple

    def __len__(self):
        # Return the number of SparseAutoencoder instances
        return len(self.autoencoders)
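
The Cartesian-product expansion above is easy to verify in isolation. A self-contained sketch, using a hypothetical DummyCfg dataclass in place of the real runner config (DummyCfg is not part of the repository):

import dataclasses
from itertools import product

@dataclasses.dataclass
class DummyCfg:  # hypothetical stand-in for LanguageModelSAERunnerConfig
    hook_point: str = "blocks.{layer}.hook_resid_pre"
    hook_point_layer: int = 3
    lr: object = None
    l1_coefficient: object = None

cfg = DummyCfg(lr=[4e-4, 1e-3], l1_coefficient=[8e-5, 1e-4])

# Same logic as SAEGroup._init_autoencoders, minus the SAE construction
swept = {k: v for k, v in vars(cfg).items() if isinstance(v, list)}
keys, values = zip(*swept.items())
for combo in product(*values):
    sub_cfg = dataclasses.replace(cfg, **dict(zip(keys, combo)))
    sub_cfg.hook_point = sub_cfg.hook_point.format(layer=sub_cfg.hook_point_layer)
    print(sub_cfg.hook_point, sub_cfg.lr, sub_cfg.l1_coefficient)
# -> four configs, one per (lr, l1_coefficient) combination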