Skip to content

Commit

Permalink
stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Dec 2, 2023
1 parent afcc239 commit 19d278a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
5 changes: 5 additions & 0 deletions sae_training/lm_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class LanguageModelSAERunnerConfig:
context_size: int = 128

# Resampling protocol args
feature_sampling_method: str = "l2" # None, l2, or anthropic
feature_sampling_window: int = 100
feature_reinit_scale: float = 0.2
dead_feature_threshold: float = 1e-8
Expand All @@ -61,6 +62,9 @@ class LanguageModelSAERunnerConfig:
def __post_init__(self):
self.d_sae = self.d_in * self.expansion_factor
self.tokens_per_buffer = self.train_batch_size * self.context_size * self.n_batches_in_buffer

if self.feature_sampling_method not in [None, "l2", "anthropic"]:
raise ValueError(f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}")

def language_model_sae_runner(cfg):

Expand All @@ -83,6 +87,7 @@ def language_model_sae_runner(cfg):
sparse_autoencoder = train_sae_on_language_model(
model, sparse_autoencoder, activations_buffer,
batch_size = cfg.train_batch_size,
feature_sampling_method = cfg.feature_sampling_method,
feature_sampling_window = cfg.feature_sampling_window,
feature_reinit_scale = cfg.feature_reinit_scale,
dead_feature_threshold = cfg.dead_feature_threshold,
Expand Down
5 changes: 3 additions & 2 deletions sae_training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def train_sae_on_language_model(
sae: SAE,
data_loader_buffer,
batch_size: int = 1024,
feature_sampling_method: str = "l2", # None, l2, or anthropic
feature_sampling_window: int = 100, # how many training steps between resampling the features / considering neurons dead
feature_reinit_scale: float = 0.2, # how much to scale the resampled features by
dead_feature_threshold: float = 1e-8, # how infrequently a feature has to be active to be considered dead
Expand Down Expand Up @@ -42,7 +43,7 @@ def train_sae_on_language_model(
sae.set_decoder_norm_to_unit_norm()

# Resample dead neurons
if (feature_sampling_window is not None) and ((n_training_steps + 1) % feature_sampling_window == 0):
if (feature_sampling_method is not None) and ((n_training_steps + 1) % feature_sampling_window == 0):

# Get the fraction of neurons active in the previous window
frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)
Expand Down Expand Up @@ -167,7 +168,7 @@ def train_sae_on_language_model(
def get_new_dataloader(data_loader_buffer, n_remaining_batches_in_buffer, batch_size):
buffer = data_loader_buffer.get_buffer()
dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))
n_remaining_batches_in_buffer = len(dataloader)
n_remaining_batches_in_buffer = len(dataloader) // 2 # only ever use half the buffer .
return dataloader, n_remaining_batches_in_buffer


Expand Down
15 changes: 8 additions & 7 deletions tests/benchmark/test_language_model_sae_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ def test_language_model_sae_runner_mlp_out():

# Data Generating Function (Model + Training Distribution)
model_name = "gelu-2l",
hook_point = "blocks.1.hook_mlp_out",
hook_point_layer = 0,
hook_point = "blocks.1.hook_resid_pre",
hook_point_layer = 1,
d_in = 512,
dataset_path = "NeelNanda/c4-tokenized-2b",
is_dataset_tokenized=True,

# SAE Parameters
expansion_factor = 64, # determines the dimension of the SAE.
expansion_factor = 32, # determines the dimension of the SAE.

# Training Parameters
lr = 1e-4,
Expand All @@ -29,18 +29,19 @@ def test_language_model_sae_runner_mlp_out():
context_size = 128,

# Activation Store Parameters
n_batches_in_buffer = 8,
total_training_tokens = 25_000_00 * 60,
n_batches_in_buffer = 16,
total_training_tokens = 25_000_00 * 15,
store_batch_size = 32,

# Resampling protocol
feature_sampling_window = 1000,
feature_sampling_method = None,
feature_sampling_window = 500,
feature_reinit_scale = 0.2,
dead_feature_threshold = 1e-8,

# WANDB
log_to_wandb = True,
wandb_project= "mats_sae_training_language_models",
wandb_project= "mats_sae_training_language_models_hack_day",
wandb_entity = None,

# Misc
Expand Down

0 comments on commit 19d278a

Please sign in to comment.