add ability to train on attn heads
jbloom-md committed Dec 12, 2023
1 parent 9614a23 commit 18cfaad
Showing 4 changed files with 113 additions and 38 deletions.
23 changes: 17 additions & 6 deletions sae_training/activations_store.py
@@ -110,12 +110,23 @@ def get_batch_tokens(self):
def get_activations(self, batch_tokens):

act_name = self.cfg.hook_point
activations = self.model.run_with_cache(
    batch_tokens,
    names_filter=act_name,
)[1][act_name]
hook_point_layer = self.cfg.hook_point_layer
if self.cfg.hook_point_head_index is not None:
    # Stop the forward pass early and keep only the requested head's activations.
    activations = self.model.run_with_cache(
        batch_tokens,
        names_filter=act_name,
        stop_at_layer=hook_point_layer,
    )[1][act_name][:, :, self.cfg.hook_point_head_index]
else:
    activations = self.model.run_with_cache(
        batch_tokens,
        names_filter=act_name,
        stop_at_layer=hook_point_layer + 1,
    )[1][act_name]

return activations

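A minimal standalone sketch (not part of the diff) of what the new head-index branch does, assuming a TransformerLens HookedTransformer; the model name, prompt, layer, and head choice are illustrative. blocks.0.attn.hook_q caches per-head activations of shape [batch, pos, n_heads, d_head], so indexing the third dimension keeps a single head:

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gelu-2l", device="cpu")
tokens = model.to_tokens("Attention heads can be trained on individually.")

hook_point = "blocks.0.attn.hook_q"  # per-head hook: [batch, pos, n_heads, d_head]
head_index = 2

# stop_at_layer is exclusive, so layer + 1 runs the block that owns this hook.
_, cache = model.run_with_cache(
    tokens,
    names_filter=hook_point,
    stop_at_layer=1,
)
head_acts = cache[hook_point][:, :, head_index]  # [batch, pos, d_head]
print(head_acts.shape)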
5 changes: 2 additions & 3 deletions sae_training/config.py
@@ -1,5 +1,6 @@

from dataclasses import dataclass
from typing import Optional

import torch

@@ -16,6 +17,7 @@ class LanguageModelSAERunnerConfig:
model_name: str = "gelu-2l"
hook_point: str = "blocks.0.hook_mlp_out"
hook_point_layer: int = 0
hook_point_head_index: Optional[int] = None
dataset_path: str = "NeelNanda/c4-tokenized-2b"
is_dataset_tokenized: bool = True

@@ -66,9 +68,6 @@ def __post_init__(self):
unique_id = wandb.util.generate_id()
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

assert self.dead_feature_window < self.feature_sampling_window, "dead_feature_window must be < feature_sampling_window"


# Print out some useful info:
n_tokens_per_buffer = self.store_batch_size * self.context_size * self.n_batches_in_buffer
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
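For reference, a sketch (not from this commit) of how the new hook_point_head_index field might be used; the layer and head values mirror the new test fixture below, fields not shown in this diff are assumed to keep their defaults, and the d_in value is a placeholder that would need to match the per-head width (d_head) of the chosen hook:

from sae_training.config import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    model_name="gelu-2l",
    hook_point="blocks.0.attn.hook_q",   # per-head hook point
    hook_point_layer=1,                  # layer convention mirrors the test fixture
    hook_point_head_index=2,             # new field: train the SAE on a single head
    dataset_path="NeelNanda/c4-tokenized-2b",
    is_dataset_tokenized=True,
    d_in=64,                             # placeholder: must equal d_head for a single-head hook
)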
66 changes: 37 additions & 29 deletions sae_training/train_sae_on_language_model.py
@@ -75,8 +75,17 @@ def train_sae_on_language_model(
feature_reinit_scale,
optimizer
)

else:
# for all the dead neurons, set the feature sparsity to the dead feature threshold
act_freq_scores[is_dead] = sparse_autoencoder.cfg.dead_feature_threshold * n_frac_active_tokens
if n_resampled_neurons > 0:
print(f"Resampled {n_resampled_neurons} neurons")
if use_wandb:
wandb.log(
{
"metrics/n_resampled_neurons": n_resampled_neurons,
},
step=n_training_steps,
)
n_resampled_neurons = 0

# Update learning rate here if using scheduler.
@@ -124,39 +133,38 @@
.mean()
.item(),
"details/n_training_tokens": n_training_tokens,
"metrics/n_resampled_neurons": n_resampled_neurons,
"metrics/current_learning_rate": current_learning_rate,
},
step=n_training_steps,
)

# record loss frequently, but not all the time.
if (n_training_steps + 1) % (wandb_log_frequency * 10) == 0:
# Now we want the reconstruction loss.
recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(sparse_autoencoder, model, activation_store, num_batches=3)
wandb.log(
{
"metrics/reconstruction_score": recons_score,
"metrics/ce_loss_without_sae": ntp_loss,
"metrics/ce_loss_with_sae": recons_loss,
"metrics/ce_loss_with_ablation": zero_abl_loss,
},
step=n_training_steps,
)
# record loss frequently, but not all the time.
if use_wandb and ((n_training_steps + 1) % (wandb_log_frequency * 10) == 0):
# Now we want the reconstruction loss.
recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(sparse_autoencoder, model, activation_store, num_batches=3)

wandb.log(
{
"metrics/reconstruction_score": recons_score,
"metrics/ce_loss_without_sae": ntp_loss,
"metrics/ce_loss_with_sae": recons_loss,
"metrics/ce_loss_with_ablation": zero_abl_loss,

},
step=n_training_steps,
)

# use feature window to log feature sparsity
if ((n_training_steps + 1) % feature_sampling_window == 0):
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
wandb.log(
{
"plots/feature_density_histogram": wandb.Histogram(
log_feature_sparsity.tolist()
),
},
step=n_training_steps,
)
# use feature window to log feature sparsity
if use_wandb and ((n_training_steps + 1) % feature_sampling_window == 0):
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
wandb.log(
{
"plots/feature_density_histogram": wandb.Histogram(
log_feature_sparsity.tolist()
),
},
step=n_training_steps,
)


pbar.set_description(
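One detail in the feature-density logging above: the 1e-10 added inside torch.log10 keeps features that never fired finite on the log scale rather than mapping them to -inf. A tiny illustration with made-up sparsity values:

import torch

feature_sparsity = torch.tensor([0.0, 1e-6, 0.05])  # fraction of tokens each feature fires on
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10)
print(log_feature_sparsity)  # approximately tensor([-10.0000, -6.0000, -1.3010])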
57 changes: 57 additions & 0 deletions tests/unit/test_activations_store.py
@@ -53,6 +53,49 @@ def cfg():
return mock_config



@pytest.fixture
def cfg_head_hook():
"""
Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig.
"""
# Create a mock object with the necessary attributes
mock_config = SimpleNamespace()
mock_config.model_name = TEST_MODEL
mock_config.hook_point = "blocks.0.attn.hook_q"
mock_config.hook_point_layer = 1
mock_config.hook_point_head_index = 2
mock_config.dataset_path = TEST_DATASET
mock_config.is_dataset_tokenized = False
mock_config.d_in = 4
mock_config.expansion_factor = 2
mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor
mock_config.l1_coefficient = 2e-3
mock_config.lr = 2e-4
mock_config.train_batch_size = 32
mock_config.context_size = 128

mock_config.feature_sampling_method = None
mock_config.feature_sampling_window = 50
mock_config.feature_reinit_scale = 0.1
mock_config.dead_feature_threshold = 1e-7

mock_config.n_batches_in_buffer = 4
mock_config.total_training_tokens = 1_000_000
mock_config.store_batch_size = 32

mock_config.log_to_wandb = False
mock_config.wandb_project = "test_project"
mock_config.wandb_entity = "test_entity"
mock_config.wandb_log_frequency = 10
mock_config.device = torch.device("cpu")
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
mock_config.dtype = torch.float32

return mock_config


@pytest.fixture
def model():
return HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")
@@ -61,6 +104,10 @@ def model():
def activation_store(cfg, model):
return ActivationsStore(cfg, model)

@pytest.fixture
def activation_store_head_hook(cfg_head_hook, model):
return ActivationsStore(cfg_head_hook, model)

def test_activations_store__init__(cfg, model):

store = ActivationsStore(cfg, model)
@@ -100,6 +147,16 @@ def test_activations_store__get_activations(activation_store):
assert activations.shape == (cfg.store_batch_size, cfg.context_size, cfg.d_in)
assert activations.device == cfg.device

def test_activations_store__get_activations_head_hook(activation_store_head_hook):

batch = activation_store_head_hook.get_batch_tokens()
activations = activation_store_head_hook.get_activations(batch)

cfg = activation_store_head_hook.cfg
assert isinstance(activations, torch.Tensor)
assert activations.shape == (cfg.store_batch_size, cfg.context_size, cfg.d_in)
assert activations.device == cfg.device

def test_activations_store__get_buffer(activation_store):


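The new fixture and test should be runnable in isolation with something like pytest tests/unit/test_activations_store.py -k head_hook (assuming pytest is installed; -k selects the new head-hook test by substring match).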
