diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py
index bbb0621f..6964ccfe 100644
--- a/sae_training/activations_store.py
+++ b/sae_training/activations_store.py
@@ -110,12 +110,23 @@ def get_batch_tokens(self):
 
     def get_activations(self, batch_tokens):
         act_name = self.cfg.hook_point
-        activations = self.model.run_with_cache(
-            batch_tokens,
-            names_filter=act_name,
-        )[
-            1
-        ][act_name]
+        hook_point_layer = self.cfg.hook_point_layer
+        if self.cfg.hook_point_head_index is not None:
+            activations = self.model.run_with_cache(
+                batch_tokens,
+                names_filter=act_name,
+                stop_at_layer=hook_point_layer+1
+            )[
+                1
+            ][act_name][:,:,self.cfg.hook_point_head_index]
+        else:
+            activations = self.model.run_with_cache(
+                batch_tokens,
+                names_filter=act_name,
+                stop_at_layer=hook_point_layer+1
+            )[
+                1
+            ][act_name]
         return activations
 
diff --git a/sae_training/config.py b/sae_training/config.py
index cd87d7bb..8d95d0dd 100644
--- a/sae_training/config.py
+++ b/sae_training/config.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 
@@ -16,6 +17,7 @@ class LanguageModelSAERunnerConfig:
     model_name: str = "gelu-2l"
     hook_point: str = "blocks.0.hook_mlp_out"
     hook_point_layer: int = 0
+    hook_point_head_index: Optional[int] = None
     dataset_path: str = "NeelNanda/c4-tokenized-2b"
     is_dataset_tokenized: bool = True
 
@@ -66,9 +68,6 @@ def __post_init__(self):
 
         unique_id = wandb.util.generate_id()
         self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"
 
-        assert self.dead_feature_window < self.feature_sampling_window, "dead_feature_window must be < feature_sampling_window"
-
-        # Print out some useful info:
         n_tokens_per_buffer = self.store_batch_size * self.context_size * self.n_batches_in_buffer
         print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py
index 9c761010..c9e3685b 100644
--- a/sae_training/train_sae_on_language_model.py
+++ b/sae_training/train_sae_on_language_model.py
@@ -75,8 +75,17 @@ def train_sae_on_language_model(
                 feature_reinit_scale,
                 optimizer
             )
-
-        else:
+            # for all the dead neurons, set the feature sparsity to the dead feature threshold
+            act_freq_scores[is_dead] = sparse_autoencoder.cfg.dead_feature_threshold * n_frac_active_tokens
+            if n_resampled_neurons > 0:
+                print(f"Resampled {n_resampled_neurons} neurons")
+            if use_wandb:
+                wandb.log(
+                    {
+                        "metrics/n_resampled_neurons": n_resampled_neurons,
+                    },
+                    step=n_training_steps,
+                )
             n_resampled_neurons = 0
 
         # Update learning rate here if using scheduler.
@@ -124,39 +133,38 @@
                         .mean()
                         .item(),
                         "details/n_training_tokens": n_training_tokens,
-                        "metrics/n_resampled_neurons": n_resampled_neurons,
                         "metrics/current_learning_rate": current_learning_rate,
                     },
                     step=n_training_steps,
                 )
 
-                # record loss frequently, but not all the time.
-                if (n_training_steps + 1) % (wandb_log_frequency * 10) == 0:
-                    # Now we want the reconstruction loss.
-                    recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(sparse_autoencoder, model, activation_store, num_batches=3)
-
-                    wandb.log(
-                        {
-                            "metrics/reconstruction_score": recons_score,
-                            "metrics/ce_loss_without_sae": ntp_loss,
-                            "metrics/ce_loss_with_sae": recons_loss,
-                            "metrics/ce_loss_with_ablation": zero_abl_loss,
-
-                        },
-                        step=n_training_steps,
-                    )
+            # record loss frequently, but not all the time.
+            if use_wandb and ((n_training_steps + 1) % (wandb_log_frequency * 10) == 0):
+                # Now we want the reconstruction loss.
+                recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(sparse_autoencoder, model, activation_store, num_batches=3)
+
+                wandb.log(
+                    {
+                        "metrics/reconstruction_score": recons_score,
+                        "metrics/ce_loss_without_sae": ntp_loss,
+                        "metrics/ce_loss_with_sae": recons_loss,
+                        "metrics/ce_loss_with_ablation": zero_abl_loss,
+
+                    },
+                    step=n_training_steps,
+                )
 
-                # use feature window to log feature sparsity
-                if ((n_training_steps + 1) % feature_sampling_window == 0):
-                    log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
-                    wandb.log(
-                        {
-                            "plots/feature_density_histogram": wandb.Histogram(
-                                log_feature_sparsity.tolist()
-                            ),
-                        },
-                        step=n_training_steps,
-                    )
+            # use feature window to log feature sparsity
+            if use_wandb and ((n_training_steps + 1) % feature_sampling_window == 0):
+                log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
+                wandb.log(
+                    {
+                        "plots/feature_density_histogram": wandb.Histogram(
+                            log_feature_sparsity.tolist()
+                        ),
+                    },
+                    step=n_training_steps,
+                )
 
         pbar.set_description(
diff --git a/tests/unit/test_activations_store.py b/tests/unit/test_activations_store.py
index d7a7a780..5af0fe87 100644
--- a/tests/unit/test_activations_store.py
+++ b/tests/unit/test_activations_store.py
@@ -53,6 +53,49 @@ def cfg():
 
     return mock_config
 
 
+@pytest.fixture
+def cfg_head_hook():
+    """
+    Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig.
+    """
+    # Create a mock object with the necessary attributes
+    mock_config = SimpleNamespace()
+    mock_config.model_name = TEST_MODEL
+    mock_config.hook_point = "blocks.0.attn.hook_q"
+    mock_config.hook_point_layer = 0
+    mock_config.hook_point_head_index = 2
+    mock_config.dataset_path = TEST_DATASET
+    mock_config.is_dataset_tokenized = False
+    mock_config.d_in = 4
+    mock_config.expansion_factor = 2
+    mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor
+    mock_config.l1_coefficient = 2e-3
+    mock_config.lr = 2e-4
+    mock_config.train_batch_size = 32
+    mock_config.context_size = 128
+
+    mock_config.feature_sampling_method = None
+    mock_config.feature_sampling_window = 50
+    mock_config.feature_reinit_scale = 0.1
+    mock_config.dead_feature_threshold = 1e-7
+
+    mock_config.n_batches_in_buffer = 4
+    mock_config.total_training_tokens = 1_000_000
+    mock_config.store_batch_size = 32
+
+    mock_config.log_to_wandb = False
+    mock_config.wandb_project = "test_project"
+    mock_config.wandb_entity = "test_entity"
+    mock_config.wandb_log_frequency = 10
+    mock_config.device = torch.device("cpu")
+    mock_config.seed = 24
+    mock_config.checkpoint_path = "test/checkpoints"
+    mock_config.dtype = torch.float32
+
+    return mock_config
+
+
 @pytest.fixture
 def model():
     return HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")
 
@@ -61,6 +104,10 @@ def model():
 def activation_store(cfg, model):
     return ActivationsStore(cfg, model)
 
+
+@pytest.fixture
+def activation_store_head_hook(cfg_head_hook, model):
+    return ActivationsStore(cfg_head_hook, model)
+
 
 def test_activations_store__init__(cfg, model):
     store = ActivationsStore(cfg, model)
 
@@ -100,6 +147,16 @@ def test_activations_store__get_activations(activation_store):
     assert activations.shape == (cfg.store_batch_size, cfg.context_size, cfg.d_in)
     assert activations.device == cfg.device
 
 
+def test_activations_store__get_activations_head_hook(activation_store_head_hook):
+
+    batch = activation_store_head_hook.get_batch_tokens()
+    activations = activation_store_head_hook.get_activations(batch)
+
+    cfg = activation_store_head_hook.cfg
+    assert isinstance(activations, torch.Tensor)
+    assert activations.shape == (cfg.store_batch_size, cfg.context_size, cfg.d_in)
+    assert activations.device == cfg.device
+
 
 def test_activations_store__get_buffer(activation_store):