
Commit

reduce hist freq, don't cap re-init
jbloomAus committed Dec 6, 2023
1 parent b63f14e commit debcf0f
Showing 2 changed files with 3 additions and 3 deletions.
sae_training/sparse_autoencoder.py (1 addition, 1 deletion)
@@ -124,7 +124,7 @@ def resample_neurons(

# Draw `n_hidden_ae` samples from [0, 1, ..., batch_size-1], with probabilities proportional to l2_loss
distn = Categorical(probs = per_token_l2_loss.pow(2) / (per_token_l2_loss.pow(2).sum()))
-n_samples = min(n_dead, feature_sparsity.shape[-1] // self.cfg.expansion_factor) # don't reinit more than 10% of neurons at a time
+n_samples = n_dead#min(n_dead, feature_sparsity.shape[-1] // self.cfg.expansion_factor) # don't reinit more than 10% of neurons at a time
replacement_indices = distn.sample((n_samples,)) # shape [n_dead]

# Index into the batch of hidden activations to get our replacement values
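
For context, `resample_neurons` picks replacement directions for dead features by sampling tokens with probability proportional to their squared reconstruction (L2) loss; this commit drops the cap of `feature_sparsity.shape[-1] // self.cfg.expansion_factor`, so all `n_dead` features are resampled in one pass. A minimal sketch of that sampling step under the new behavior (the shapes and the dead-feature mask below are illustrative stand-ins, not values from the repo):

    import torch
    from torch.distributions import Categorical

    batch_size, d_hidden = 512, 2048                  # illustrative sizes only
    per_token_l2_loss = torch.rand(batch_size)        # stand-in for the per-token reconstruction loss
    is_dead = torch.rand(d_hidden) < 0.05             # stand-in for the dead-feature mask
    n_dead = int(is_dead.sum())

    # Tokens with larger squared L2 loss are more likely to seed replacement directions.
    distn = Categorical(probs=per_token_l2_loss.pow(2) / per_token_l2_loss.pow(2).sum())

    # After this commit every dead feature is resampled; previously n_samples was
    # capped at feature_sparsity.shape[-1] // expansion_factor.
    n_samples = n_dead
    replacement_indices = distn.sample((n_samples,))  # shape [n_dead]

The sampled indices are then used, as the hunk's trailing comment says, to index into the batch of hidden activations and build the replacement values.
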
sae_training/train_sae_on_language_model.py (2 additions, 2 deletions)
@@ -125,7 +125,7 @@ def train_sae_on_language_model(
step=n_training_steps,
)

-if (n_training_steps + 1) % (wandb_log_frequency * 10) == 0:
+if (n_training_steps + 1) % (wandb_log_frequency * 100) == 0:
log_feature_sparsity = torch.log(feature_sparsity + 1e-8)
wandb.log(
{
@@ -137,7 +137,7 @@
)

# Now we want the reconstruction loss.
-recons_score, _, _, _ = get_recons_loss(sparse_autoencoder, model, activation_store, num_batches=5)
+recons_score, _, _, _ = get_recons_loss(sparse_autoencoder, model, activation_store, num_batches=3)

wandb.log(
{
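
In the training loop, the effect is that feature-sparsity histograms and the reconstruction score are logged one tenth as often (every `wandb_log_frequency * 100` steps instead of `* 10`), and each of those events now averages `get_recons_loss` over 3 batches rather than 5, presumably to keep periodic evaluation cheap. A runnable toy check of the new cadence (`wandb_log_frequency = 10` below is a hypothetical config value, not taken from the repo):

    wandb_log_frequency = 10                   # hypothetical; the real value comes from the run config

    old_interval = wandb_log_frequency * 10    # histogram/recons logging every 100 steps before this commit
    new_interval = wandb_log_frequency * 100   # every 1,000 steps after it

    total_steps = 5_000
    old_events = sum(1 for s in range(total_steps) if (s + 1) % old_interval == 0)
    new_events = sum(1 for s in range(total_steps) if (s + 1) % new_interval == 0)
    print(old_events, new_events)              # -> 50 5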
