diff --git a/sae_training/lm_runner.py b/sae_training/lm_runner.py
index 4696bf80..835ebb55 100644
--- a/sae_training/lm_runner.py
+++ b/sae_training/lm_runner.py
@@ -36,6 +36,7 @@ class LanguageModelSAERunnerConfig:
     context_size: int = 128
 
     # Resampling protocol args
+    feature_sampling_method: str = "l2" # None, l2, or anthropic
     feature_sampling_window: int = 100
     feature_reinit_scale: float = 0.2
     dead_feature_threshold: float = 1e-8
@@ -61,6 +62,9 @@ class LanguageModelSAERunnerConfig:
     def __post_init__(self):
         self.d_sae = self.d_in * self.expansion_factor
         self.tokens_per_buffer = self.train_batch_size * self.context_size * self.n_batches_in_buffer
+
+        if self.feature_sampling_method not in [None, "l2", "anthropic"]:
+            raise ValueError(f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}")
 
 
 def language_model_sae_runner(cfg):
@@ -83,6 +87,7 @@ def language_model_sae_runner(cfg):
     sparse_autoencoder = train_sae_on_language_model(
         model, sparse_autoencoder, activations_buffer,
         batch_size = cfg.train_batch_size,
+        feature_sampling_method = cfg.feature_sampling_method,
         feature_sampling_window = cfg.feature_sampling_window,
         feature_reinit_scale = cfg.feature_reinit_scale,
         dead_feature_threshold = cfg.dead_feature_threshold,
diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py
index be02bfa0..abc12764 100644
--- a/sae_training/train_sae_on_language_model.py
+++ b/sae_training/train_sae_on_language_model.py
@@ -15,6 +15,7 @@ def train_sae_on_language_model(
     sae: SAE,
     data_loader_buffer,
     batch_size: int = 1024,
+    feature_sampling_method: str = "l2", # None, l2, or anthropic
    feature_sampling_window: int = 100, # how many training steps between resampling the features / considering neurons dead
     feature_reinit_scale: float = 0.2, # how much to scale the resampled features by
     dead_feature_threshold: float = 1e-8, # how infrequently a feature has to be active to be considered dead
@@ -42,7 +43,7 @@
         sae.set_decoder_norm_to_unit_norm()
 
         # Resample dead neurons
-        if (feature_sampling_window is not None) and ((n_training_steps + 1) % feature_sampling_window == 0):
+        if (feature_sampling_method is not None) and ((n_training_steps + 1) % feature_sampling_window == 0):
 
             # Get the fraction of neurons active in the previous window
             frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)
@@ -167,7 +168,7 @@
 def get_new_dataloader(data_loader_buffer, n_remaining_batches_in_buffer, batch_size):
     buffer = data_loader_buffer.get_buffer()
     dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))
-    n_remaining_batches_in_buffer = len(dataloader)
+    n_remaining_batches_in_buffer = len(dataloader) // 2 # only ever use half the buffer
     return dataloader, n_remaining_batches_in_buffer
 
 
diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py
index 3d7e6e55..674622b5 100644
--- a/tests/benchmark/test_language_model_sae_runner.py
+++ b/tests/benchmark/test_language_model_sae_runner.py
@@ -13,14 +13,14 @@ def test_language_model_sae_runner_mlp_out():
 
         # Data Generating Function (Model + Training Distribution)
         model_name = "gelu-2l",
-        hook_point = "blocks.1.hook_mlp_out",
-        hook_point_layer = 0,
+        hook_point = "blocks.1.hook_resid_pre",
+        hook_point_layer = 1,
         d_in = 512,
         dataset_path = "NeelNanda/c4-tokenized-2b",
         is_dataset_tokenized=True,
 
         # SAE Parameters
-        expansion_factor = 64, # determines the dimension of the SAE.
+        expansion_factor = 32, # determines the dimension of the SAE.
 
         # Training Parameters
         lr = 1e-4,
@@ -29,18 +29,19 @@ def test_language_model_sae_runner_mlp_out():
         context_size = 128,
 
         # Activation Store Parameters
-        n_batches_in_buffer = 8,
-        total_training_tokens = 25_000_00 * 60,
+        n_batches_in_buffer = 16,
+        total_training_tokens = 25_000_00 * 15,
         store_batch_size = 32,
 
         # Resampling protocol
-        feature_sampling_window = 1000,
+        feature_sampling_method = None,
+        feature_sampling_window = 500,
         feature_reinit_scale = 0.2,
         dead_feature_threshold = 1e-8,
 
         # WANDB
         log_to_wandb = True,
-        wandb_project= "mats_sae_training_language_models",
+        wandb_project= "mats_sae_training_language_models_hack_day",
         wandb_entity = None,
 
         # Misc
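
Note: a minimal sketch of how the new `feature_sampling_method` option behaves at config-construction time. The class and field names come from the diff above; treating the remaining `LanguageModelSAERunnerConfig` fields as having usable defaults is an assumption of this sketch.

```python
from sae_training.lm_runner import LanguageModelSAERunnerConfig

# "l2" and "anthropic" select a resampling strategy; None disables the
# resampling branch entirely, since the training loop now gates on
# feature_sampling_method rather than feature_sampling_window.
cfg = LanguageModelSAERunnerConfig(feature_sampling_method="anthropic")

# Any other value fails fast in __post_init__:
try:
    LanguageModelSAERunnerConfig(feature_sampling_method="l1")
except ValueError as err:
    print(err)  # feature_sampling_method must be None, l2, or anthropic. Got l1
```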
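The `// 2` in `get_new_dataloader` halves the batch budget per buffer: training now draws only the first half of each shuffled buffer before requesting a fresh one, and the remaining batches are discarded. A self-contained sketch of that consumption pattern, with a toy tensor standing in for `data_loader_buffer.get_buffer()` (the shapes here are illustrative):

```python
import torch
from torch.utils.data import DataLoader

buffer = torch.randn(64, 512)  # toy stand-in for the activation buffer
dataloader = iter(DataLoader(buffer, batch_size=4, shuffle=True))

# 64 rows / batch_size 4 = 16 batches, but only 8 will ever be drawn.
n_remaining_batches_in_buffer = len(dataloader) // 2

for _ in range(n_remaining_batches_in_buffer):
    batch = next(dataloader)  # shape (4, 512)
# ...at this point the training loop would fetch a fresh buffer.
```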