Skip to content

Commit

Permalink
start saving log sparsity
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Jan 28, 2024
1 parent e863ed7 commit 4d6df6f
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions sae_training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def train_sae_on_language_model(
if n_checkpoints > 0 and n_training_tokens > checkpoint_thresholds[0]:
cfg = sparse_autoencoder.cfg
path = f"{sparse_autoencoder.cfg.checkpoint_path}/{n_training_tokens}_{sparse_autoencoder.get_name()}.pt"
log_feature_sparsity_path = f"{sparse_autoencoder.cfg.checkpoint_path}/{n_training_tokens}_{sparse_autoencoder.get_name()}_log_feature_sparsity.pt"
sparse_autoencoder.save_model(path)
torch.save(log_feature_sparsity, log_feature_sparsity_path)
checkpoint_thresholds.pop(0)
if len(checkpoint_thresholds) == 0:
n_checkpoints = 0
Expand All @@ -235,6 +237,13 @@ def train_sae_on_language_model(
)
model_artifact.add_file(path)
wandb.log_artifact(model_artifact)

sparsity_artifact = wandb.Artifact(
f"{sparse_autoencoder.get_name()}_log_feature_sparsity", type="log_feature_sparsity", metadata=dict(cfg.__dict__)
)
sparsity_artifact.add_file(log_feature_sparsity_path)
wandb.log_artifact(sparsity_artifact)


n_training_steps += 1

Expand Down

0 comments on commit 4d6df6f

Please sign in to comment.