diff --git a/sae_training/sparse_autoencoder.py b/sae_training/sparse_autoencoder.py index 9710827d..d55b5e6a 100644 --- a/sae_training/sparse_autoencoder.py +++ b/sae_training/sparse_autoencoder.py @@ -5,7 +5,7 @@ import gzip import os import pickle -from typing import Any +from typing import Any, NamedTuple import einops import torch @@ -16,6 +16,15 @@ from sae_training.geometric_median import compute_geometric_median +class ForwardOutput(NamedTuple): + sae_out: torch.Tensor + feature_acts: torch.Tensor + loss: torch.Tensor + mse_loss: torch.Tensor + l1_loss: torch.Tensor + ghost_grad_loss: torch.Tensor + + class SparseAutoencoder(HookedRootModule): """ """ @@ -138,7 +147,14 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None) l1_loss = self.l1_coefficient * sparsity loss = mse_loss + l1_loss + mse_loss_ghost_resid - return sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid + return ForwardOutput( + sae_out=sae_out, + feature_acts=feature_acts, + loss=loss, + mse_loss=mse_loss, + l1_loss=l1_loss, + ghost_grad_loss=mse_loss_ghost_resid, + ) @torch.no_grad() def initialize_b_dec_with_precalculated(self, origin: torch.Tensor): diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py index 2d56cadf..5e4ae915 100644 --- a/sae_training/train_sae_on_language_model.py +++ b/sae_training/train_sae_on_language_model.py @@ -1,7 +1,9 @@ -from typing import Any, cast +from dataclasses import dataclass +from typing import Any, NamedTuple, cast import torch -from torch.optim import Adam +from torch.optim import Adam, Optimizer +from torch.optim.lr_scheduler import LRScheduler from tqdm import tqdm from transformer_lens import HookedTransformer @@ -11,6 +13,31 @@ from sae_training.geometric_median import compute_geometric_median from sae_training.optim import get_scheduler from sae_training.sae_group import SAEGroup +from sae_training.sparse_autoencoder import SparseAutoencoder + + +@dataclass +class SAETrainContext: + """ + Context to track during training for a single SAE + """ + + act_freq_scores: torch.Tensor + n_forward_passes_since_fired: torch.Tensor + n_frac_active_tokens: int + optimizer: Optimizer + scheduler: LRScheduler + + @property + def feature_sparsity(self) -> torch.Tensor: + return self.act_freq_scores / self.n_frac_active_tokens + + +@dataclass +class TrainSAEGroupOutput: + sae_group: SAEGroup + checkpoint_paths: list[str] + log_feature_sparsities: list[torch.Tensor] def train_sae_on_language_model( @@ -23,12 +50,36 @@ def train_sae_on_language_model( dead_feature_threshold: float = 1e-8, # how infrequently a feature has to be active to be considered dead use_wandb: bool = False, wandb_log_frequency: int = 50, -): +) -> SAEGroup: + """ + @deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility. 
+ """ + return train_sae_group_on_language_model( + model, + sae_group, + activation_store, + batch_size, + n_checkpoints, + feature_sampling_window, + use_wandb, + wandb_log_frequency, + ).sae_group + + +def train_sae_group_on_language_model( + model: HookedTransformer, + sae_group: SAEGroup, + activation_store: ActivationsStore, + batch_size: int = 1024, + n_checkpoints: int = 0, + feature_sampling_window: int = 1000, # how many training steps between resampling the features / considiring neurons dead + use_wandb: bool = False, + wandb_log_frequency: int = 50, +) -> TrainSAEGroupOutput: total_training_tokens = sae_group.cfg.total_training_tokens total_training_steps = total_training_tokens // batch_size n_training_steps = 0 n_training_tokens = 0 - log_feature_sparsity = None checkpoint_thresholds = [] if n_checkpoints > 0: @@ -36,47 +87,159 @@ def train_sae_on_language_model( range(0, total_training_tokens, total_training_tokens // n_checkpoints) )[1:] - # things to store for each sae: - # act_freq_scores, n_forward_passes_since_fired, n_frac_active_tokens, optimizer, scheduler, - num_saes = len(sae_group) - # track active features + all_layers = sae_group.cfg.hook_point_layer + if not isinstance(all_layers, list): + all_layers = [all_layers] - act_freq_scores = [ - torch.zeros( - cast(int, sparse_autoencoder.cfg.d_sae), - device=sparse_autoencoder.cfg.device, - ) - for sparse_autoencoder in sae_group + wandb_suffix = _wandb_log_suffix(sae_group.cfg, sae_group.cfg) + train_contexts = [ + _build_train_context(sae, total_training_steps) for sae in sae_group ] - n_forward_passes_since_fired = [ - torch.zeros( - cast(int, sparse_autoencoder.cfg.d_sae), - device=sparse_autoencoder.cfg.device, - ) - for sparse_autoencoder in sae_group - ] - n_frac_active_tokens = [0 for _ in range(num_saes)] - - optimizer = [Adam(sae.parameters(), lr=sae.cfg.lr) for sae in sae_group] - scheduler = [ - get_scheduler( - sae.cfg.lr_scheduler_name, - optimizer=opt, - warm_up_steps=sae.cfg.lr_warm_up_steps, - training_steps=total_training_steps, - lr_end=sae.cfg.lr / 10, # heuristic for now. + _init_sae_group_b_decs(sae_group, activation_store, all_layers) + + pbar = tqdm(total=total_training_tokens, desc="Training SAE") + checkpoint_paths: list[str] = [] + while n_training_tokens < total_training_tokens: + # Do a training step. + layer_acts = activation_store.next_batch() + n_training_tokens += batch_size + + mse_losses: list[torch.Tensor] = [] + l1_losses: list[torch.Tensor] = [] + + for ( + sparse_autoencoder, + ctx, + ) in zip(sae_group, train_contexts): + step_output = _train_step( + sparse_autoencoder=sparse_autoencoder, + layer_acts=layer_acts, + ctx=ctx, + feature_sampling_window=feature_sampling_window, + use_wandb=use_wandb, + n_training_steps=n_training_steps, + all_layers=all_layers, + batch_size=batch_size, + wandb_suffix=wandb_suffix, + ) + mse_losses.append(step_output.mse_loss) + l1_losses.append(step_output.l1_loss) + if use_wandb: + with torch.no_grad(): + if (n_training_steps + 1) % wandb_log_frequency == 0: + wandb.log( + _build_train_step_log_dict( + sparse_autoencoder, + step_output, + ctx, + wandb_suffix, + n_training_tokens, + ), + step=n_training_steps, + ) + + # record loss frequently, but not all the time. 
+ if (n_training_steps + 1) % (wandb_log_frequency * 10) == 0: + sparse_autoencoder.eval() + run_evals( + sparse_autoencoder, + activation_store, + model, + n_training_steps, + suffix=wandb_suffix, + ) + sparse_autoencoder.train() + + # checkpoint if at checkpoint frequency + if checkpoint_thresholds and n_training_tokens > checkpoint_thresholds[0]: + checkpoint_path = _save_checkpoint( + sae_group, + train_contexts=train_contexts, + checkpoint_name=n_training_tokens, + ).path + checkpoint_paths.append(checkpoint_path) + checkpoint_thresholds.pop(0) + + ############### + + n_training_steps += 1 + pbar.set_description( + f"{n_training_steps}| MSE Loss {torch.stack(mse_losses).mean().item():.3f} | L1 {torch.stack(l1_losses).mean().item():.3f}" ) - for sae, opt in zip(sae_group, optimizer) - ] + pbar.update(batch_size) + + # save final sae group to checkpoints folder + final_checkpoint = _save_checkpoint( + sae_group, + train_contexts=train_contexts, + checkpoint_name="final", + wandb_aliases=["final_model"], + ) + checkpoint_paths.append(final_checkpoint.path) + + return TrainSAEGroupOutput( + sae_group=sae_group, + checkpoint_paths=checkpoint_paths, + log_feature_sparsities=final_checkpoint.log_feature_sparsities, + ) - all_layers = sae_group.cfg.hook_point_layer - if not isinstance(all_layers, list): - all_layers = [all_layers] - # compute the geometric median of the activations of each layer +def _wandb_log_suffix(cfg: Any, hyperparams: Any): + # Create a mapping from cfg list keys to their corresponding hyperparams attributes + key_mapping = { + "hook_point_layer": "layer", + "l1_coefficient": "coeff", + "lp_norm": "l", + "lr": "lr", + } + # Generate the suffix by iterating over the keys that have list values in cfg + suffix = "".join( + f"_{key_mapping.get(key, key)}{getattr(hyperparams, key, '')}" + for key, value in vars(cfg).items() + if isinstance(value, list) + ) + return suffix + + +def _build_train_context( + sae: SparseAutoencoder, total_training_steps: int +) -> SAETrainContext: + act_freq_scores = torch.zeros( + cast(int, sae.cfg.d_sae), + device=sae.cfg.device, + ) + n_forward_passes_since_fired = torch.zeros( + cast(int, sae.cfg.d_sae), + device=sae.cfg.device, + ) + n_frac_active_tokens = 0 + + optimizer = Adam(sae.parameters(), lr=sae.cfg.lr) + scheduler = get_scheduler( + sae.cfg.lr_scheduler_name, + optimizer=optimizer, + warm_up_steps=sae.cfg.lr_warm_up_steps, + training_steps=total_training_steps, + lr_end=sae.cfg.lr / 10, # heuristic for now. + ) + + return SAETrainContext( + act_freq_scores=act_freq_scores, + n_forward_passes_since_fired=n_forward_passes_since_fired, + n_frac_active_tokens=n_frac_active_tokens, + optimizer=optimizer, + scheduler=scheduler, + ) + + +def _init_sae_group_b_decs( + sae_group: SAEGroup, activation_store: ActivationsStore, all_layers: list[int] +) -> None: + """ + extract all activations at a certain layer and use for sae b_dec initialization + """ geometric_medians = {} - # extract all activations at a certain layer and use for sae initialization for sae in sae_group: hyperparams = sae.cfg sae_layer_id = all_layers.index(hyperparams.hook_point_layer) @@ -95,196 +258,175 @@ def train_sae_on_language_model( :, sae_layer_id, : ] sae.initialize_b_dec_with_mean(layer_acts) - sae.train() - - pbar = tqdm(total=total_training_tokens, desc="Training SAE") - while n_training_tokens < total_training_tokens: - # Do a training step. 
- layer_acts = activation_store.next_batch() - n_training_tokens += batch_size - # init these here to avoid uninitialized vars - mse_loss = torch.tensor(0.0) - l1_loss = torch.tensor(0.0) - for ( - i, - (sparse_autoencoder), - ) in enumerate(sae_group): - assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy - hyperparams = sparse_autoencoder.cfg - layer_id = all_layers.index(hyperparams.hook_point_layer) - sae_in = layer_acts[:, layer_id, :] - - sparse_autoencoder.train() - # Make sure the W_dec is still zero-norm - sparse_autoencoder.set_decoder_norm_to_unit_norm() - - # log and then reset the feature sparsity every feature_sampling_window steps - if (n_training_steps + 1) % feature_sampling_window == 0: - feature_sparsity = act_freq_scores[i] / n_frac_active_tokens[i] - log_feature_sparsity = ( - torch.log10(feature_sparsity + 1e-10).detach().cpu() - ) - - if use_wandb: - suffix = wandb_log_suffix(sae_group.cfg, hyperparams) - wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy()) - wandb.log( - { - f"metrics/mean_log10_feature_sparsity{suffix}": log_feature_sparsity.mean().item(), - f"plots/feature_density_line_chart{suffix}": wandb_histogram, - f"sparsity/below_1e-5{suffix}": (feature_sparsity < 1e-5) - .sum() - .item(), - f"sparsity/below_1e-6{suffix}": (feature_sparsity < 1e-6) - .sum() - .item(), - }, - step=n_training_steps, - ) - - act_freq_scores[i] = torch.zeros( - sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device - ) - n_frac_active_tokens[i] = 0 - - scheduler[i].step() - optimizer[i].zero_grad() - - ghost_grad_neuron_mask = ( - n_forward_passes_since_fired[i] - > sparse_autoencoder.cfg.dead_feature_window - ).bool() - - # Forward and Backward Passes - ( - sae_out, - feature_acts, - loss, - mse_loss, - l1_loss, - ghost_grad_loss, - ) = sparse_autoencoder( - sae_in, - ghost_grad_neuron_mask, +@dataclass +class TrainStepOutput: + sae_in: torch.Tensor + sae_out: torch.Tensor + feature_acts: torch.Tensor + loss: torch.Tensor + mse_loss: torch.Tensor + l1_loss: torch.Tensor + ghost_grad_loss: torch.Tensor + ghost_grad_neuron_mask: torch.Tensor + + +def _train_step( + sparse_autoencoder: SparseAutoencoder, + layer_acts: torch.Tensor, + ctx: SAETrainContext, + feature_sampling_window: int, # how many training steps between resampling the features / considiring neurons dead + use_wandb: bool, + n_training_steps: int, + all_layers: list[int], + batch_size: int, + wandb_suffix: str, +) -> TrainStepOutput: + assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy + hyperparams = sparse_autoencoder.cfg + layer_id = all_layers.index(hyperparams.hook_point_layer) + sae_in = layer_acts[:, layer_id, :] + + sparse_autoencoder.train() + # Make sure the W_dec is still zero-norm + sparse_autoencoder.set_decoder_norm_to_unit_norm() + + # log and then reset the feature sparsity every feature_sampling_window steps + if (n_training_steps + 1) % feature_sampling_window == 0: + feature_sparsity = ctx.feature_sparsity + log_feature_sparsity = _log_feature_sparsity(feature_sparsity) + + if use_wandb: + wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy()) + wandb.log( + { + f"metrics/mean_log10_feature_sparsity{wandb_suffix}": log_feature_sparsity.mean().item(), + f"plots/feature_density_line_chart{wandb_suffix}": wandb_histogram, + f"sparsity/below_1e-5{wandb_suffix}": (feature_sparsity < 1e-5) + .sum() + .item(), + f"sparsity/below_1e-6{wandb_suffix}": (feature_sparsity < 1e-6) + .sum() + .item(), + }, + step=n_training_steps, ) - did_fire 
= (feature_acts > 0).float().sum(-2) > 0 - n_forward_passes_since_fired[i] += 1 - n_forward_passes_since_fired[i][did_fire] = 0 - - with torch.no_grad(): - # Calculate the sparsities, and add it to a list, calculate sparsity metrics - act_freq_scores[i] += (feature_acts.abs() > 0).float().sum(0) - n_frac_active_tokens[i] += batch_size - feature_sparsity = act_freq_scores[i] / n_frac_active_tokens[i] - - if use_wandb and ((n_training_steps + 1) % wandb_log_frequency == 0): - # metrics for currents acts - l0 = (feature_acts > 0).float().sum(-1).mean() - current_learning_rate = optimizer[i].param_groups[0]["lr"] - - per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze() - total_variance = (sae_in - sae_in.mean(0)).pow(2).sum(-1) - explained_variance = 1 - per_token_l2_loss / total_variance - - suffix = wandb_log_suffix(sae_group.cfg, hyperparams) - wandb.log( - { - # losses - f"losses/mse_loss{suffix}": mse_loss.item(), - f"losses/l1_loss{suffix}": l1_loss.item() - / sparse_autoencoder.l1_coefficient, # normalize by l1 coefficient - f"losses/ghost_grad_loss{suffix}": ghost_grad_loss.item(), - f"losses/overall_loss{suffix}": loss.item(), - # variance explained - f"metrics/explained_variance{suffix}": explained_variance.mean().item(), - f"metrics/explained_variance_std{suffix}": explained_variance.std().item(), - f"metrics/l0{suffix}": l0.item(), - # sparsity - f"sparsity/mean_passes_since_fired{suffix}": n_forward_passes_since_fired[ - i - ] - .mean() - .item(), - f"sparsity/dead_features{suffix}": ghost_grad_neuron_mask.sum().item(), - f"details/current_learning_rate{suffix}": current_learning_rate, - "details/n_training_tokens": n_training_tokens, - }, - step=n_training_steps, - ) - - # record loss frequently, but not all the time. - if use_wandb and ( - (n_training_steps + 1) % (wandb_log_frequency * 10) == 0 - ): - sparse_autoencoder.eval() - suffix = wandb_log_suffix(sae_group.cfg, hyperparams) - run_evals( - sparse_autoencoder, - activation_store, - model, - n_training_steps, - suffix=suffix, - ) - sparse_autoencoder.train() - - loss.backward() - sparse_autoencoder.remove_gradient_parallel_to_decoder_directions() - optimizer[i].step() - # checkpoint if at checkpoint frequency - if n_checkpoints > 0 and n_training_tokens > checkpoint_thresholds[0]: - path = f"{sae_group.cfg.checkpoint_path}/{n_training_tokens}_{sae_group.get_name()}.pt" - for sae in sae_group: - sae.set_decoder_norm_to_unit_norm() - sae_group.save_model(path) - - log_feature_sparsity_path = f"{sae_group.cfg.checkpoint_path}/{n_training_tokens}_{sae_group.get_name()}_log_feature_sparsity.pt" - log_feature_sparsity = [] - for sae_id in range(len(sae_group)): - feature_sparsity = ( - act_freq_scores[sae_id] / n_frac_active_tokens[sae_id] - ) - log_feature_sparsity.append( - torch.log10(feature_sparsity + 1e-10).detach().cpu() - ) - torch.save(log_feature_sparsity, log_feature_sparsity_path) + ctx.act_freq_scores = torch.zeros( + sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device + ) + ctx.n_frac_active_tokens = 0 + + ghost_grad_neuron_mask = ( + ctx.n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window + ).bool() + + # Forward and Backward Passes + ( + sae_out, + feature_acts, + loss, + mse_loss, + l1_loss, + ghost_grad_loss, + ) = sparse_autoencoder( + sae_in, + ghost_grad_neuron_mask, + ) + did_fire = (feature_acts > 0).float().sum(-2) > 0 + ctx.n_forward_passes_since_fired += 1 + ctx.n_forward_passes_since_fired[did_fire] = 0 + + with torch.no_grad(): + # Calculate the 
sparsities, and add it to a list, calculate sparsity metrics + ctx.act_freq_scores += (feature_acts.abs() > 0).float().sum(0) + ctx.n_frac_active_tokens += batch_size + + ctx.optimizer.zero_grad() + loss.backward() + sparse_autoencoder.remove_gradient_parallel_to_decoder_directions() + ctx.optimizer.step() + ctx.scheduler.step() + + return TrainStepOutput( + sae_in=sae_in, + sae_out=sae_out, + feature_acts=feature_acts, + loss=loss, + mse_loss=mse_loss, + l1_loss=l1_loss, + ghost_grad_loss=ghost_grad_loss, + ghost_grad_neuron_mask=ghost_grad_neuron_mask, + ) - checkpoint_thresholds.pop(0) - if len(checkpoint_thresholds) == 0: - n_checkpoints = 0 - if sae_group.cfg.log_to_wandb: - model_artifact = wandb.Artifact( - f"{sae_group.get_name()}", - type="model", - metadata=dict(sae_group.cfg.__dict__), - ) - model_artifact.add_file(path) - wandb.log_artifact(model_artifact) - - sparsity_artifact = wandb.Artifact( - f"{sae_group.get_name()}_log_feature_sparsity", - type="log_feature_sparsity", - metadata=dict(sae_group.cfg.__dict__), - ) - sparsity_artifact.add_file(log_feature_sparsity_path) - wandb.log_artifact(sparsity_artifact) - - ############### - n_training_steps += 1 - pbar.set_description( - f"{n_training_steps}| MSE Loss {mse_loss.item():.3f} | L1 {l1_loss.item():.3f}" - ) - pbar.update(batch_size) +def _build_train_step_log_dict( + sparse_autoencoder: SparseAutoencoder, + output: TrainStepOutput, + ctx: SAETrainContext, + wandb_suffix: str, + n_training_tokens: int, +) -> dict[str, Any]: + sae_in = output.sae_in + sae_out = output.sae_out + feature_acts = output.feature_acts + mse_loss = output.mse_loss + l1_loss = output.l1_loss + ghost_grad_loss = output.ghost_grad_loss + loss = output.loss + ghost_grad_neuron_mask = output.ghost_grad_neuron_mask + + # metrics for currents acts + l0 = (feature_acts > 0).float().sum(-1).mean() + current_learning_rate = ctx.optimizer.param_groups[0]["lr"] + + per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze() + total_variance = (sae_in - sae_in.mean(0)).pow(2).sum(-1) + explained_variance = 1 - per_token_l2_loss / total_variance + + return { + # losses + f"losses/mse_loss{wandb_suffix}": mse_loss.item(), + f"losses/l1_loss{wandb_suffix}": l1_loss.item() + / sparse_autoencoder.l1_coefficient, # normalize by l1 coefficient + f"losses/ghost_grad_loss{wandb_suffix}": ghost_grad_loss.item(), + f"losses/overall_loss{wandb_suffix}": loss.item(), + # variance explained + f"metrics/explained_variance{wandb_suffix}": explained_variance.mean().item(), + f"metrics/explained_variance_std{wandb_suffix}": explained_variance.std().item(), + f"metrics/l0{wandb_suffix}": l0.item(), + # sparsity + f"sparsity/mean_passes_since_fired{wandb_suffix}": ctx.n_forward_passes_since_fired.mean().item(), + f"sparsity/dead_features{wandb_suffix}": ghost_grad_neuron_mask.sum().item(), + f"details/current_learning_rate{wandb_suffix}": current_learning_rate, + "details/n_training_tokens": n_training_tokens, + } + + +class SaveCheckpointOutput(NamedTuple): + path: str + log_feature_sparsity_path: str + log_feature_sparsities: list[torch.Tensor] - # save sae group to checkpoints folder - path = f"{sae_group.cfg.checkpoint_path}/final_{sae_group.get_name()}.pt" + +def _save_checkpoint( + sae_group: SAEGroup, + train_contexts: list[SAETrainContext], + checkpoint_name: int | str, + wandb_aliases: list[str] | None = None, +) -> SaveCheckpointOutput: + path = ( + f"{sae_group.cfg.checkpoint_path}/{checkpoint_name}_{sae_group.get_name()}.pt" + ) for sae in sae_group: 
sae.set_decoder_norm_to_unit_norm() sae_group.save_model(path) - + log_feature_sparsity_path = f"{sae_group.cfg.checkpoint_path}/{checkpoint_name}_{sae_group.get_name()}_log_feature_sparsity.pt" + log_feature_sparsities = [ + _log_feature_sparsity(ctx.feature_sparsity) for ctx in train_contexts + ] + torch.save(log_feature_sparsities, log_feature_sparsity_path) if sae_group.cfg.log_to_wandb: model_artifact = wandb.Artifact( f"{sae_group.get_name()}", @@ -292,19 +434,8 @@ def train_sae_on_language_model( metadata=dict(sae_group.cfg.__dict__), ) model_artifact.add_file(path) - wandb.log_artifact(model_artifact, aliases=["final_model"]) - - # need to fix this - log_feature_sparsity_path = f"{sae_group.cfg.checkpoint_path}/final_{sae_group.get_name()}_log_feature_sparsity.pt" - log_feature_sparsity = [] - for sae_id in range(len(sae_group)): - feature_sparsity = act_freq_scores[sae_id] / n_frac_active_tokens[sae_id] - log_feature_sparsity.append( - torch.log10(feature_sparsity + 1e-10).detach().cpu() - ) - torch.save(log_feature_sparsity, log_feature_sparsity_path) + wandb.log_artifact(model_artifact, aliases=wandb_aliases) - if sae_group.cfg.log_to_wandb: sparsity_artifact = wandb.Artifact( f"{sae_group.get_name()}_log_feature_sparsity", type="log_feature_sparsity", @@ -312,23 +443,10 @@ def train_sae_on_language_model( ) sparsity_artifact.add_file(log_feature_sparsity_path) wandb.log_artifact(sparsity_artifact) + return SaveCheckpointOutput(path, log_feature_sparsity_path, log_feature_sparsities) - return sae_group - -def wandb_log_suffix(cfg: Any, hyperparams: Any): - # Create a mapping from cfg list keys to their corresponding hyperparams attributes - key_mapping = { - "hook_point_layer": "layer", - "l1_coefficient": "coeff", - "lp_norm": "l", - "lr": "lr", - } - - # Generate the suffix by iterating over the keys that have list values in cfg - suffix = "".join( - f"_{key_mapping.get(key, key)}{getattr(hyperparams, key, '')}" - for key, value in vars(cfg).items() - if isinstance(value, list) - ) - return suffix +def _log_feature_sparsity( + feature_sparsity: torch.Tensor, eps: float = 1e-10 +) -> torch.Tensor: + return torch.log10(feature_sparsity + eps).detach().cpu() diff --git a/tests/unit/test_train_sae_on_language_model.py b/tests/unit/test_train_sae_on_language_model.py new file mode 100644 index 00000000..096185cf --- /dev/null +++ b/tests/unit/test_train_sae_on_language_model.py @@ -0,0 +1,330 @@ +from pathlib import Path +from typing import Any, Callable +from unittest.mock import patch + +import pytest +import torch +from datasets import Dataset +from torch import Tensor +from transformer_lens import HookedTransformer + +from sae_training.activations_store import ActivationsStore +from sae_training.optim import get_scheduler +from sae_training.sae_group import SAEGroup +from sae_training.sparse_autoencoder import ForwardOutput, SparseAutoencoder +from sae_training.train_sae_on_language_model import ( + SAETrainContext, + TrainStepOutput, + _build_train_step_log_dict, + _log_feature_sparsity, + _save_checkpoint, + _train_step, + train_sae_group_on_language_model, +) +from tests.unit.helpers import build_sae_cfg + + +def build_train_ctx( + sae: SparseAutoencoder, + act_freq_scores: Tensor | None = None, + n_forward_passes_since_fired: Tensor | None = None, + n_frac_active_tokens: int = 0, +) -> SAETrainContext: + """ + Factory helper to build a default SAETrainContext object. 
+ """ + assert sae.cfg.d_sae is not None + optimizer = torch.optim.Adam(sae.parameters(), lr=sae.cfg.lr) + return SAETrainContext( + act_freq_scores=( + torch.zeros(sae.cfg.d_sae) if act_freq_scores is None else act_freq_scores + ), + n_forward_passes_since_fired=( + torch.zeros(sae.cfg.d_sae) + if n_forward_passes_since_fired is None + else n_forward_passes_since_fired + ), + n_frac_active_tokens=n_frac_active_tokens, + optimizer=optimizer, + scheduler=get_scheduler(None, optimizer=optimizer), + ) + + +def modify_sae_output( + sae: SparseAutoencoder, modifier: Callable[[ForwardOutput], ForwardOutput] +): + """ + Helper to modify the output of the SAE forward pass for use in patching, for use in patch side_effect. + We need real grads during training, so we can't just mock the whole forward pass directly. + """ + + def modified_forward(*args: Any, **kwargs: Any): + output = SparseAutoencoder.forward(sae, *args, **kwargs) + return modifier(output) + + return modified_forward + + +def test_train_step__reduces_loss_when_called_repeatedly_on_same_acts() -> None: + cfg = build_sae_cfg(d_in=64, d_sae=128, hook_point_layer=0) + sae = SparseAutoencoder(cfg) + ctx = build_train_ctx(sae) + + layer_acts = torch.randn(10, 1, 64) + + # intentionally train on the same activations 5 times to ensure loss decreases + train_outputs = [ + _train_step( + sparse_autoencoder=sae, + ctx=ctx, + layer_acts=layer_acts, + all_layers=[0], + feature_sampling_window=1000, + use_wandb=False, + n_training_steps=10, + batch_size=10, + wandb_suffix="", + ) + for _ in range(5) + ] + + # ensure loss decreases with each training step + for output, next_output in zip(train_outputs[:-1], train_outputs[1:]): + assert output.loss > next_output.loss + assert ctx.n_frac_active_tokens == 50 # should increment each step by batch_size + + +def test_train_step__output_looks_reasonable() -> None: + cfg = build_sae_cfg(d_in=64, d_sae=128, hook_point_layer=0) + sae = SparseAutoencoder(cfg) + ctx = build_train_ctx(sae) + + layer_acts = torch.randn(10, 2, 64) + + output = _train_step( + sparse_autoencoder=sae, + ctx=ctx, + layer_acts=layer_acts, + all_layers=[0], + feature_sampling_window=1000, + use_wandb=False, + n_training_steps=10, + batch_size=10, + wandb_suffix="", + ) + + assert output.loss > 0 + # only hook_point_layer=0 acts should be passed to the SAE + assert torch.allclose(output.sae_in, layer_acts[:, 0, :]) + assert output.sae_out.shape == output.sae_in.shape + assert output.feature_acts.shape == (10, 128) # batch_size, d_sae + assert output.ghost_grad_neuron_mask.shape == (128,) + assert output.loss.shape == () + assert output.mse_loss.shape == () + assert output.ghost_grad_loss.shape == () + # ghots grads shouldn't trigger until dead_feature_window, which hasn't been reached yet + assert torch.all(output.ghost_grad_neuron_mask == False) # noqa + assert output.ghost_grad_loss == 0 + assert ctx.n_frac_active_tokens == 10 + assert ctx.act_freq_scores.sum() > 0 # at least SOME acts should have fired + assert torch.allclose( + ctx.act_freq_scores, (output.feature_acts.abs() > 0).float().sum(0) + ) + + +def test_train_step__ghost_grads_mask() -> None: + cfg = build_sae_cfg(d_in=2, d_sae=4, dead_feature_window=5) + sae = SparseAutoencoder(cfg) + ctx = build_train_ctx( + sae, n_forward_passes_since_fired=torch.tensor([0, 4, 7, 9]).float() + ) + + output = _train_step( + sparse_autoencoder=sae, + ctx=ctx, + layer_acts=torch.randn(10, 1, 2), + all_layers=[0], + feature_sampling_window=1000, + use_wandb=False, + n_training_steps=10, + 
batch_size=10, + wandb_suffix="", + ) + assert torch.all( + output.ghost_grad_neuron_mask == torch.Tensor([False, False, True, True]) + ) + + +def test_train_step__sparsity_updates_based_on_feature_act_sparsity() -> None: + cfg = build_sae_cfg(d_in=2, d_sae=4, hook_point_layer=0) + sae = SparseAutoencoder(cfg) + + feature_acts = torch.tensor([[0, 0, 0, 0], [1, 0, 0, 1], [1, 0, 1, 1]]).float() + layer_acts = torch.randn(3, 1, 2) + + ctx = build_train_ctx( + sae, + n_frac_active_tokens=9, + act_freq_scores=torch.tensor([0, 3, 7, 1]).float(), + n_forward_passes_since_fired=torch.tensor([8, 2, 0, 0]).float(), + ) + with patch.object( + sae, + "forward", + side_effect=modify_sae_output( + sae, lambda out: out._replace(feature_acts=feature_acts) + ), + ): + train_output = _train_step( + sparse_autoencoder=sae, + ctx=ctx, + layer_acts=layer_acts, + all_layers=[0], + feature_sampling_window=1000, + use_wandb=False, + n_training_steps=10, + batch_size=3, + wandb_suffix="", + ) + + # should increase by batch_size + assert ctx.n_frac_active_tokens == 12 + # add freq scores for all non-zero feature acts + assert torch.allclose( + ctx.act_freq_scores, + torch.tensor([2, 3, 8, 3]).float(), + ) + assert torch.allclose( + ctx.n_forward_passes_since_fired, + torch.tensor([0, 3, 0, 0]).float(), + ) + + # the outputs from the SAE should be included in the train output + assert train_output.feature_acts is feature_acts + + +def test_log_feature_sparsity__handles_zeroes_by_default_fp32() -> None: + fp32_zeroes = torch.tensor([0], dtype=torch.float32) + assert _log_feature_sparsity(fp32_zeroes).item() != float("-inf") + + +# TODO: currently doesn't work for fp16, we should address this +@pytest.mark.skip(reason="Currently doesn't work for fp16") +def test_log_feature_sparsity__handles_zeroes_by_default_fp16() -> None: + fp16_zeroes = torch.tensor([0], dtype=torch.float16) + assert _log_feature_sparsity(fp16_zeroes).item() != float("-inf") + + +def test_build_train_step_log_dict() -> None: + cfg = build_sae_cfg( + d_in=2, d_sae=4, hook_point_layer=0, lr=2e-4, l1_coefficient=1e-2 + ) + sae = SparseAutoencoder(cfg) + ctx = build_train_ctx( + sae, + act_freq_scores=torch.tensor([0, 3, 1, 0]).float(), + n_frac_active_tokens=10, + n_forward_passes_since_fired=torch.tensor([4, 0, 0, 0]).float(), + ) + train_output = TrainStepOutput( + sae_in=torch.tensor([[-1, 0], [0, 2], [1, 1]]).float(), + sae_out=torch.tensor([[0, 0], [0, 2], [0.5, 1]]).float(), + feature_acts=torch.tensor([[0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 1, 1]]).float(), + loss=torch.tensor(0.5), + mse_loss=torch.tensor(0.25), + l1_loss=torch.tensor(0.1), + ghost_grad_loss=torch.tensor(0.15), + ghost_grad_neuron_mask=torch.tensor([False, True, False, True]), + ) + + log_dict = _build_train_step_log_dict( + sae, train_output, ctx, wandb_suffix="-wandbftw", n_training_tokens=123 + ) + assert log_dict == { + "losses/mse_loss-wandbftw": 0.25, + # l1 loss is scaled by l1_coefficient + "losses/l1_loss-wandbftw": pytest.approx(10), + "losses/ghost_grad_loss-wandbftw": pytest.approx(0.15), + "losses/overall_loss-wandbftw": 0.5, + "metrics/explained_variance-wandbftw": 0.75, + "metrics/explained_variance_std-wandbftw": 0.25, + "metrics/l0-wandbftw": 2.0, + "sparsity/mean_passes_since_fired-wandbftw": 1.0, + "sparsity/dead_features-wandbftw": 2, + "details/current_learning_rate-wandbftw": 2e-4, + "details/n_training_tokens": 123, + } + + +def test_save_checkpoint(tmp_path: Path) -> None: + checkpoint_dir = tmp_path / "checkpoint" + cfg = 
build_sae_cfg(checkpoint_path=checkpoint_dir, d_in=25, d_sae=100) + sae_group = SAEGroup(cfg) + assert len(sae_group.autoencoders) == 1 + ctx = build_train_ctx( + sae_group.autoencoders[0], + act_freq_scores=torch.randint(0, 100, (100,)), + n_forward_passes_since_fired=torch.randint(0, 100, (100,)), + n_frac_active_tokens=123, + ) + res = _save_checkpoint(sae_group, [ctx], "test_checkpoint") + assert res.path == str( + checkpoint_dir / f"test_checkpoint_{sae_group.get_name()}.pt" + ) + assert res.log_feature_sparsity_path == str( + checkpoint_dir + / f"test_checkpoint_{sae_group.get_name()}_log_feature_sparsity.pt" + ) + assert torch.allclose( + res.log_feature_sparsities[0], _log_feature_sparsity(ctx.feature_sparsity) + ) + + # now, load the saved checkpoints to make sure they match what we saved + loaded_sae_group = torch.load(res.path) + loaded_log_sparsities = torch.load(res.log_feature_sparsity_path) + + assert isinstance(loaded_sae_group, SAEGroup) + assert len(loaded_sae_group.autoencoders) == 1 + assert loaded_sae_group.get_name() == sae_group.get_name() + + loaded_state_dict = loaded_sae_group.autoencoders[0].state_dict() + original_state_dict = sae_group.autoencoders[0].state_dict() + + assert list(loaded_state_dict.keys()) == list(original_state_dict.keys()) + for orig_val, loaded_val in zip( + original_state_dict.values(), loaded_state_dict.values() + ): + assert torch.allclose(orig_val, loaded_val) + + assert torch.allclose( + loaded_log_sparsities[0], _log_feature_sparsity(ctx.feature_sparsity) + ) + + +def test_train_sae_group_on_language_model__runs_and_outputs_look_reasonable( + ts_model: HookedTransformer, + tmp_path: Path, +) -> None: + checkpoint_dir = tmp_path / "checkpoint" + cfg = build_sae_cfg( + checkpoint_path=checkpoint_dir, + train_batch_size=32, + total_training_tokens=100, + context_size=8, + ) + # just a tiny datast which will run quickly + dataset = Dataset.from_list([{"text": "hello world"}] * 1000) + activation_store = ActivationsStore(cfg, model=ts_model, dataset=dataset) + sae_group = SAEGroup(cfg) + res = train_sae_group_on_language_model( + model=ts_model, + sae_group=sae_group, + activation_store=activation_store, + batch_size=32, + ) + assert res.checkpoint_paths == [ + str(checkpoint_dir / f"final_{sae_group.get_name()}.pt") + ] + assert len(res.log_feature_sparsities) == 1 + assert res.log_feature_sparsities[0].shape == (cfg.d_sae,) + assert res.sae_group is sae_group
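Reviewer note on the ForwardOutput change: ForwardOutput is a NamedTuple, so it is still a plain tuple at runtime. Existing call sites that positionally unpack the six return values of SparseAutoencoder.forward keep working unchanged (as _train_step above still does), while new code can use named fields and _replace, which is how the new unit tests patch feature_acts. A minimal sketch of both call styles, assuming the tests' build_sae_cfg helper from tests/unit/helpers.py (not part of this diff) and mirroring the kwargs the new tests use:

import torch

from sae_training.sparse_autoencoder import SparseAutoencoder
from tests.unit.helpers import build_sae_cfg  # test helper, assumed importable

cfg = build_sae_cfg(d_in=64, d_sae=128, hook_point_layer=0)
sae = SparseAutoencoder(cfg)
acts = torch.randn(10, 64)  # (batch, d_in)
no_dead_features = torch.zeros(128, dtype=torch.bool)  # same mask shape/dtype that _train_step passes

# new style: named access to the forward outputs
out = sae(acts, no_dead_features)
print(out.loss.item(), out.mse_loss.item(), out.l1_loss.item())

# old style: positional unpacking still works because NamedTuple subclasses tuple
sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss = sae(acts, no_dead_features)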
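For reviewers trying the branch end to end, a sketch of the new entry point and its TrainSAEGroupOutput return value, closely following test_train_sae_group_on_language_model__runs_and_outputs_look_reasonable. The "tiny-stories-1M" model name and the local "checkpoints" directory are assumptions standing in for the test's ts_model fixture and tmp_path, and the sketch relies on build_sae_cfg's defaults being compatible with that model, as they are with the fixture in the test:

import os

from datasets import Dataset
from transformer_lens import HookedTransformer

from sae_training.activations_store import ActivationsStore
from sae_training.sae_group import SAEGroup
from sae_training.train_sae_on_language_model import train_sae_group_on_language_model
from tests.unit.helpers import build_sae_cfg  # test helper, assumed importable

model = HookedTransformer.from_pretrained("tiny-stories-1M")  # assumed stand-in for the ts_model fixture
os.makedirs("checkpoints", exist_ok=True)
cfg = build_sae_cfg(
    checkpoint_path="checkpoints",  # assumed local output directory
    train_batch_size=32,
    total_training_tokens=100,
    context_size=8,
)
dataset = Dataset.from_list([{"text": "hello world"}] * 1000)  # tiny dataset, as in the test
activation_store = ActivationsStore(cfg, model=model, dataset=dataset)
sae_group = SAEGroup(cfg)

result = train_sae_group_on_language_model(
    model=model,
    sae_group=sae_group,
    activation_store=activation_store,
    batch_size=32,
)
print(result.checkpoint_paths)                 # the final checkpoint path is the last entry
print(result.log_feature_sparsities[0].shape)  # (d_sae,) log10 feature sparsity per SAE

The deprecated train_sae_on_language_model wrapper forwards to this function and returns only result.sae_group, so existing callers keep working but do not see checkpoint paths or sparsities.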