diff --git a/activation_storing.py b/activation_storing.py
index d29cc594..92f535c4 100644
--- a/activation_storing.py
+++ b/activation_storing.py
@@ -26,7 +26,7 @@
 
     # Activation Store Parameters
     n_batches_in_buffer = 16,
-    total_training_tokens = 300_000_000,
+    total_training_tokens = 10_000_000,
     store_batch_size = 64,
 
     # Activation caching shuffle parameters
diff --git a/lp_sae_training.py b/lp_sae_training.py
index ac6d6884..0bc93ec5 100644
--- a/lp_sae_training.py
+++ b/lp_sae_training.py
@@ -12,8 +12,8 @@
 
     # Data Generating Function (Model + Training Distibuion)
    model_name = "gpt2-small",
-    hook_point = "blocks.2.hook_resid_pre",
-    hook_point_layer = 2,
+    hook_point = "blocks.{layer}.hook_resid_pre",
+    hook_point_layer = 6,
    d_in = 768,
    dataset_path = "Skylion007/openwebtext",
    is_dataset_tokenized=False,
@@ -24,7 +24,27 @@
 
    # Training Parameters
    lr = 4e-4,
-    l1_coefficient = 8e-5,
+    l1_coefficient = [ #8e-5
+        # 1e-9,
+        # 1e-8,
+        1e-7,
+        1e-6,
+        1e-5,
+        1e-4,
+    ],
+    lp_norm = [
+        # 0.1,
+        # 0.2,
+        # 0.3,
+        # 0.4,
+        0.5,
+        0.6,
+        0.7,
+        0.8,
+        0.9,
+        1,
+        # 1.1,
+    ],
    lr_scheduler_name="constantwithwarmup",
    train_batch_size = 4096,
    context_size = 128,
@@ -32,7 +52,7 @@
 
    # Activation Store Parameters
    n_batches_in_buffer = 128,
-    total_training_tokens = 1_000_000 * 300,
+    total_training_tokens = 300_000_000,
    store_batch_size = 32,
 
    # Dead Neurons and Sparsity
@@ -50,10 +70,10 @@
 
    # Misc
    device = "cuda",
    seed = 42,
-    n_checkpoints = 10,
+    n_checkpoints = 2,
    checkpoint_path = "checkpoints",
    dtype = torch.float32,
-    use_cached_activations = True,
+    use_cached_activations = False,
    )
 
 sparse_autoencoder = language_model_sae_runner(cfg)
\ No newline at end of file
diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py
index 05751f4c..ed0d02f8 100644
--- a/sae_training/activations_store.py
+++ b/sae_training/activations_store.py
@@ -148,7 +148,7 @@ def get_activations(self, batch_tokens, get_loss=False):
         """
         layers = self.cfg.hook_point_layer if isinstance(self.cfg.hook_point_layer, list) else [self.cfg.hook_point_layer]
         act_names = [self.cfg.hook_point.format(layer = layer) for layer in layers]
-        hook_point_max_layer = self.cfg.hook_point_layer
+        hook_point_max_layer = max(layers)
         if self.cfg.hook_point_head_index is not None:
             layerwise_activations = self.model.run_with_cache(
                 batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1
@@ -176,7 +176,8 @@ def get_buffer(self, n_batches_in_buffer):
         batch_size = self.cfg.store_batch_size
         d_in = self.cfg.d_in
         total_size = batch_size * n_batches_in_buffer
-        num_layers = len(self.cfg.hook_points)  # Number of hook points or layers
+        num_layers = len(self.cfg.hook_point_layer) if isinstance(self.cfg.hook_point_layer, list) \
+            else 1  # Number of hook points or layers
 
         if self.cfg.use_cached_activations:
             # Load the activations from disk
@@ -268,7 +269,7 @@ def get_data_loader(
             dim=0,
         )
 
-        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[1])]
+        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
 
         # 2. put 50 % in storage
         self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
 
diff --git a/sae_training/sae_group.py b/sae_training/sae_group.py
index c395b1a2..6c466abd 100644
--- a/sae_training/sae_group.py
+++ b/sae_training/sae_group.py
@@ -16,22 +16,24 @@ def _init_autoencoders(self, cfg):
         # Extract all hyperparameter lists from cfg
         hyperparameters = {k: v for k, v in vars(cfg).items() if isinstance(v, list)}
-        keys, values = zip(*hyperparameters.items())
+        if len(hyperparameters) > 0:
+            keys, values = zip(*hyperparameters.items())
+        else:
+            keys, values = (), ([()],)  # Ensure product(*values) yields one combination
 
         # Create all combinations of hyperparameters
         for combination in product(*values):
-            cfg_copy = dataclasses.replace(cfg)
             params = dict(zip(keys, combination))
-            cfg_copy.update(params)
+            cfg_copy = dataclasses.replace(cfg, **params)
             # Insert the layer into the hookpoint
-            cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg.copy.hook_point_layer)
+            cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg_copy.hook_point_layer)
             # Create and store both the SparseAutoencoder instance and its parameters
-            self.autoencoders.append((SparseAutoencoder(cfg_copy), cfg_copy))
+            self.autoencoders.append(SparseAutoencoder(cfg_copy))
 
     def __iter__(self):
         # Make SAEGroup iterable over its SparseAutoencoder instances and their parameters
-        for ae, params in self.autoencoders:
-            yield ae, params  # Yielding as a tuple
+        for ae in self.autoencoders:
+            yield ae  # Yield each SparseAutoencoder instance
 
     def __len__(self):
         # Return the number of SparseAutoencoder instances
diff --git a/sae_training/sparse_autoencoder.py b/sae_training/sparse_autoencoder.py
index 1c889bc6..fb321027 100644
--- a/sae_training/sparse_autoencoder.py
+++ b/sae_training/sparse_autoencoder.py
@@ -138,13 +138,18 @@ def forward(self, x, dead_neuron_mask=None):
             loss = mse_loss + l1_loss + mse_loss_ghost_resid
 
         return sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid
+
+    @torch.no_grad()
+    def initialize_b_dec_with_precalculated(self, origin):
+        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
+        self.b_dec.data = out
 
     @torch.no_grad()
-    def initialize_b_dec(self, activation_store):
+    def initialize_b_dec(self, all_activations):
         if self.cfg.b_dec_init_method == "geometric_median":
-            self.initialize_b_dec_with_geometric_median(activation_store)
+            self.initialize_b_dec_with_geometric_median(all_activations)
         elif self.cfg.b_dec_init_method == "mean":
-            self.initialize_b_dec_with_mean(activation_store)
+            self.initialize_b_dec_with_mean(all_activations)
         elif self.cfg.b_dec_init_method == "zeros":
             pass
         else:
@@ -153,9 +158,8 @@ def initialize_b_dec(self, activation_store):
             )
 
     @torch.no_grad()
-    def initialize_b_dec_with_geometric_median(self, activation_store):
+    def initialize_b_dec_with_geometric_median(self, all_activations):
         previous_b_dec = self.b_dec.clone().cpu()
-        all_activations = activation_store.storage_buffer.detach().cpu()
         out = compute_geometric_median(
             all_activations, skip_typechecks=True, maxiter=100, per_component=False
         ).median
@@ -173,9 +177,8 @@ def initialize_b_dec_with_geometric_median(self, activation_store):
         self.b_dec.data = out
 
     @torch.no_grad()
-    def initialize_b_dec_with_mean(self, activation_store):
+    def initialize_b_dec_with_mean(self, all_activations):
         previous_b_dec = self.b_dec.clone().cpu()
-        all_activations = activation_store.storage_buffer.detach().cpu()
         out = all_activations.mean(dim=0)
 
         previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
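The sparse_autoencoder.py hunks above add `initialize_b_dec_with_precalculated`, and the training-loop changes below feed it one precomputed origin per hooked layer. A minimal torch-only sketch of that call pattern, assuming a storage buffer of shape `[n_tokens, n_layers, d_in]` and using the per-layer mean as a stand-in for the vendored `compute_geometric_median` that the real loop calls:

```python
import torch

# Assumed shapes for illustration; the real buffer is ActivationsStore.storage_buffer.
n_tokens, d_in = 4096, 768
all_layers = [2, 6]  # hook_point_layer given as a list
buffer = torch.randn(n_tokens, len(all_layers), d_in)

# One "origin" per layer, mirroring the geometric_medians list built once before training.
origins = [buffer[:, layer_id, :].mean(dim=0) for layer_id in range(len(all_layers))]

# What initialize_b_dec_with_precalculated does for an SAE configured for layer 6:
b_dec = torch.nn.Parameter(torch.zeros(d_in))  # stand-in for SparseAutoencoder.b_dec
with torch.no_grad():
    sae_layer_id = all_layers.index(6)
    b_dec.data = origins[sae_layer_id].to(b_dec.dtype)
```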
diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py
index b31cbc0d..610f952b 100644
--- a/sae_training/train_sae_on_language_model.py
+++ b/sae_training/train_sae_on_language_model.py
@@ -9,6 +9,7 @@
 from sae_training.optim import get_scheduler
 from sae_training.sparse_autoencoder import SparseAutoencoder
 from sae_training.sae_group import SAEGroup
+from sae_training.geom_median.src.geom_median.torch import compute_geometric_median
 
 
 def train_sae_on_language_model(
@@ -37,6 +38,7 @@ def train_sae_on_language_model(
     # act_freq_scores, n_forward_passes_since_fired, n_frac_active_tokens, optimizer, scheduler,
     num_saes = len(sparse_autoencoders)
     # track active features
+
     act_freq_scores = [
         torch.zeros(
             sparse_autoencoders.cfg.d_sae, device=sparse_autoencoders.cfg.device
@@ -49,61 +51,86 @@ def train_sae_on_language_model(
     ]
     n_frac_active_tokens = [0 for _ in range(num_saes)]
 
-    optimizers = [
-        Adam(sae.parameters(), lr=sparse_autoencoder.cfg.lr)
-        for sae, hyperparams in sparse_autoencoders
+    optimizer = [
+        Adam(sae.parameters(), lr=sae.cfg.lr)
+        for sae in sparse_autoencoders
     ]
-    schedulers = [
+    scheduler = [
         get_scheduler(
-            sparse_autoencoder.cfg.lr_scheduler_name,
-            optimizer=optimizer,
-            warm_up_steps=sparse_autoencoder.cfg.lr_warm_up_steps,
+            sae.cfg.lr_scheduler_name,
+            optimizer=opt,
+            warm_up_steps=sae.cfg.lr_warm_up_steps,
             training_steps=total_training_steps,
-            lr_end=sparse_autoencoder.cfg.lr / 10,  # heuristic for now.
-        ) for optimizer in optimizers
+            lr_end=sae.cfg.lr / 10,  # heuristic for now.
+        ) for sae, opt in zip(sparse_autoencoders, optimizer)
     ]
 
-    for sae, hyperparams in sparse_autoencoders:
-        sae.initialize_b_dec(activation_store)
+    all_layers = sparse_autoencoders.cfg.hook_point_layer
+    if not isinstance(all_layers, list):
+        all_layers = [all_layers]
+
+    # compute the geometric median of the activations of each layer
+
+    geometric_medians = []
+    for layer_id in range(len(all_layers)):
+        layer_acts = activation_store.storage_buffer.detach().cpu()[:,layer_id,:]
+
+        median = compute_geometric_median(
+            layer_acts, skip_typechecks=True, maxiter=5, per_component=False
+        ).median
+        geometric_medians.append(median)
+
+    for sae in sparse_autoencoders:
+        hyperparams = sae.cfg
+        sae_layer_id = all_layers.index(hyperparams.hook_point_layer)
+
+        # extract all activations at a certain layer and use for sae initialization
+        sae.initialize_b_dec_with_precalculated(geometric_medians[sae_layer_id])
         sae.train()
 
     pbar = tqdm(total=total_training_tokens, desc="Training SAE")
     while n_training_tokens < total_training_tokens:
         # Do a training step.
-        sae_in = activation_store.next_batch()
+        layer_acts = activation_store.next_batch()
+        n_training_tokens += batch_size
 
-        for (sparse_autoencoder, hyperparams), in zip(sparse_autoencoders, ):
+        for i, sparse_autoencoder in enumerate(sparse_autoencoders):
+            hyperparams = sparse_autoencoder.cfg
+            layer_id = all_layers.index(hyperparams.hook_point_layer)
+            sae_in = layer_acts[:,layer_id,:]
+
             sparse_autoencoder.train()
             # Make sure the W_dec is still zero-norm
             sparse_autoencoder.set_decoder_norm_to_unit_norm()
 
             # log and then reset the feature sparsity every feature_sampling_window steps
             if (n_training_steps + 1) % feature_sampling_window == 0:
-                feature_sparsity = act_freq_scores / n_frac_active_tokens
+                feature_sparsity = act_freq_scores[i] / n_frac_active_tokens[i]
                 log_feature_sparsity = torch.log10(feature_sparsity + 1e-10).detach().cpu()
 
                 if use_wandb:
+                    suffix = wandb_log_suffix(sparse_autoencoders.cfg, hyperparams)
                     wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy())
                     wandb.log(
                         {
-                            "metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
-                            "plots/feature_density_line_chart": wandb_histogram,
-                            "sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
-                            "sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
+                            f"metrics/mean_log10_feature_sparsity{suffix}": log_feature_sparsity.mean().item(),
+                            f"plots/feature_density_line_chart{suffix}": wandb_histogram,
+                            f"sparsity/below_1e-5{suffix}": (feature_sparsity < 1e-5).sum().item(),
+                            f"sparsity/below_1e-6{suffix}": (feature_sparsity < 1e-6).sum().item(),
                         },
                         step=n_training_steps,
                     )
 
-                act_freq_scores = torch.zeros(
+                act_freq_scores[i] = torch.zeros(
                     sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device
                 )
-                n_frac_active_tokens = 0
+                n_frac_active_tokens[i] = 0
 
-            scheduler.step()
-            optimizer.zero_grad()
+            scheduler[i].step()
+            optimizer[i].zero_grad()
 
             ghost_grad_neuron_mask = (
-                n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window
+                n_forward_passes_since_fired[i] > sparse_autoencoder.cfg.dead_feature_window
             ).bool()
 
@@ -120,43 +147,42 @@ def train_sae_on_language_model(
                 ghost_grad_neuron_mask,
             )
             did_fire = (feature_acts > 0).float().sum(-2) > 0
-            n_forward_passes_since_fired += 1
-            n_forward_passes_since_fired[did_fire] = 0
-
-            n_training_tokens += batch_size
+            n_forward_passes_since_fired[i] += 1
+            n_forward_passes_since_fired[i][did_fire] = 0
 
             with torch.no_grad():
                 # Calculate the sparsities, and add it to a list, calculate sparsity metrics
-                act_freq_scores += (feature_acts.abs() > 0).float().sum(0)
-                n_frac_active_tokens += batch_size
-                feature_sparsity = act_freq_scores / n_frac_active_tokens
+                act_freq_scores[i] += (feature_acts.abs() > 0).float().sum(0)
+                n_frac_active_tokens[i] += batch_size
+                feature_sparsity = act_freq_scores[i] / n_frac_active_tokens[i]
 
                 if use_wandb and ((n_training_steps + 1) % wandb_log_frequency == 0):
                     # metrics for currents acts
                     l0 = (feature_acts > 0).float().sum(-1).mean()
-                    current_learning_rate = optimizer.param_groups[0]["lr"]
+                    current_learning_rate = optimizer[i].param_groups[0]["lr"]
 
                     per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
                     total_variance = (sae_in - sae_in.mean(0)).pow(2).sum(-1)
                     explained_variance = 1 - per_token_l2_loss / total_variance
 
+                    suffix = wandb_log_suffix(sparse_autoencoders.cfg, hyperparams)
                     wandb.log(
                         {
                             # losses
-                            "losses/mse_loss": mse_loss.item(),
-                            "losses/l1_loss": l1_loss.item()
+                            f"losses/mse_loss{suffix}": mse_loss.item(),
+                            f"losses/l1_loss{suffix}": l1_loss.item()
                             / sparse_autoencoder.l1_coefficient,  # normalize by l1 coefficient
-                            "losses/ghost_grad_loss": ghost_grad_loss.item(),
-                            "losses/overall_loss": loss.item(),
+                            f"losses/ghost_grad_loss{suffix}": ghost_grad_loss.item(),
+                            f"losses/overall_loss{suffix}": loss.item(),
                             # variance explained
-                            "metrics/explained_variance": explained_variance.mean().item(),
-                            "metrics/explained_variance_std": explained_variance.std().item(),
-                            "metrics/l0": l0.item(),
+                            f"metrics/explained_variance{suffix}": explained_variance.mean().item(),
+                            f"metrics/explained_variance_std{suffix}": explained_variance.std().item(),
+                            f"metrics/l0{suffix}": l0.item(),
                             # sparsity
-                            "sparsity/mean_passes_since_fired": n_forward_passes_since_fired.mean().item(),
-                            "sparsity/dead_features": ghost_grad_neuron_mask.sum().item(),
-                            "details/n_training_tokens": n_training_tokens,
-                            "details/current_learning_rate": current_learning_rate,
+                            f"sparsity/mean_passes_since_fired{suffix}": n_forward_passes_since_fired[i].mean().item(),
+                            f"sparsity/dead_features{suffix}": ghost_grad_neuron_mask.sum().item(),
+                            f"details/n_training_tokens{suffix}": n_training_tokens,
+                            f"details/current_learning_rate{suffix}": current_learning_rate,
                         },
                         step=n_training_steps,
                     )
@@ -167,14 +193,9 @@ def train_sae_on_language_model(
                     run_evals(sparse_autoencoder, activation_store, model, n_training_steps)
                     sparse_autoencoder.train()
 
-            pbar.set_description(
-                f"{n_training_steps}| MSE Loss {mse_loss.item():.3f} | L1 {l1_loss.item():.3f}"
-            )
-            pbar.update(batch_size)
-
             loss.backward()
             sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
-            optimizer.step()
+            optimizer[i].step()
 
             # checkpoint if at checkpoint frequency
             if n_checkpoints > 0 and n_training_tokens > checkpoint_thresholds[0]:
@@ -204,7 +225,12 @@ def train_sae_on_language_model(
                     sparsity_artifact.add_file(log_feature_sparsity_path)
                     wandb.log_artifact(sparsity_artifact)
 
-            n_training_steps += 1
+        n_training_steps += 1
+        pbar.set_description(
+            f"{n_training_steps}| MSE Loss {mse_loss.item():.3f} | L1 {l1_loss.item():.3f}"
+        )
+        pbar.update(batch_size)
+
 
     # save sae to checkpoints folder
     path = f"{sparse_autoencoder.cfg.checkpoint_path}/final_{sparse_autoencoder.get_name()}.pt"
@@ -232,3 +258,17 @@ def train_sae_on_language_model(
             wandb.log_artifact(sparsity_artifact)
 
     return sparse_autoencoder
+
+def wandb_log_suffix(cfg, hyperparams):
+    # Create a mapping from cfg list keys to their corresponding hyperparams attributes
+    key_mapping = {
+        "hook_point_layer": "layer",
+        "l1_coefficient": "coeff",
+        "lp_norm": "l",
+        "lr": "lr"
+    }
+
+    # Generate the suffix by iterating over the keys that have list values in cfg
+    suffix = "".join(f"_{key_mapping.get(key, key)}{getattr(hyperparams, key, '')}"
+                    for key, value in vars(cfg).items() if isinstance(value, list))
+    return suffix
\ No newline at end of file
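The `wandb_log_suffix` helper added above disambiguates metric keys when several SAEs log to the same wandb run. A self-contained sketch of the suffix it produces; `DummyCfg` and `DummyHyperparams` are illustrative stand-ins, not the repo's real classes:

```python
from dataclasses import dataclass, field

@dataclass
class DummyCfg:  # stand-in for the runner config; only the sweep-relevant fields
    hook_point_layer: int = 6
    l1_coefficient: list = field(default_factory=lambda: [1e-7, 1e-6, 1e-5, 1e-4])
    lp_norm: list = field(default_factory=lambda: [0.5, 0.6, 0.7, 0.8, 0.9, 1])

@dataclass
class DummyHyperparams:  # stand-in for one SAE's cfg inside the group
    hook_point_layer: int = 6
    l1_coefficient: float = 1e-7
    lp_norm: float = 0.5

def wandb_log_suffix(cfg, hyperparams):
    # Same logic as the helper above: only cfg fields that are lists contribute a suffix part.
    key_mapping = {"hook_point_layer": "layer", "l1_coefficient": "coeff", "lp_norm": "l", "lr": "lr"}
    return "".join(
        f"_{key_mapping.get(key, key)}{getattr(hyperparams, key, '')}"
        for key, value in vars(cfg).items()
        if isinstance(value, list)
    )

print(wandb_log_suffix(DummyCfg(), DummyHyperparams()))
# -> "_coeff1e-07_l0.5", so a metric key becomes e.g. "losses/mse_loss_coeff1e-07_l0.5"
```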
diff --git a/sae_training/utils.py b/sae_training/utils.py
index c2633a56..b25d91cb 100644
--- a/sae_training/utils.py
+++ b/sae_training/utils.py
@@ -6,6 +6,7 @@
 from sae_training.activations_store import ActivationsStore
 from sae_training.config import LanguageModelSAERunnerConfig
 from sae_training.sparse_autoencoder import SparseAutoencoder
+from sae_training.sae_group import SAEGroup
 
 
 class LMSparseAutoencoderSessionloader:
@@ -21,7 +22,7 @@ def __init__(self, cfg: LanguageModelSAERunnerConfig):
 
     def load_session(
         self,
-    ) -> Tuple[HookedTransformer, SparseAutoencoder, ActivationsStore]:
+    ) -> Tuple[HookedTransformer, SAEGroup, ActivationsStore]:
         """
         Loads a session for training a sparse autoencoder on a language model.
         """
 
@@ -66,10 +67,10 @@ def get_model(self, model_name: str):
 
     def initialize_sparse_autoencoder(self, cfg: LanguageModelSAERunnerConfig):
         """
-        Initializes a sparse autoencoder
+        Initializes a sparse autoencoder group, which contains multiple sparse autoencoders
         """
-        sparse_autoencoder = SparseAutoencoder(cfg)
+        sparse_autoencoder = SAEGroup(cfg)
         return sparse_autoencoder
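For completeness, a sketch of how `SAEGroup._init_autoencoders` (changed in sae_training/sae_group.py above) expands list-valued config fields into one config per combination; `DummyCfg` is a hypothetical stand-in for `LanguageModelSAERunnerConfig`. With the lp_sae_training.py values above (4 l1 coefficients x 6 Lp norms), the same expansion yields 24 SAEs trained in one run.

```python
import dataclasses
from itertools import product

@dataclasses.dataclass
class DummyCfg:  # hypothetical stand-in for LanguageModelSAERunnerConfig
    hook_point: str = "blocks.{layer}.hook_resid_pre"
    hook_point_layer: int = 6
    l1_coefficient: list = dataclasses.field(default_factory=lambda: [1e-7, 1e-6])
    lp_norm: list = dataclasses.field(default_factory=lambda: [0.5, 1.0])

cfg = DummyCfg()
hyperparameters = {k: v for k, v in vars(cfg).items() if isinstance(v, list)}
if len(hyperparameters) > 0:
    keys, values = zip(*hyperparameters.items())
else:
    keys, values = (), ([()],)  # product(*values) still yields one empty combination

configs = []
for combination in product(*values):
    # One config per hyperparameter combination, with the layer baked into the hook point.
    cfg_copy = dataclasses.replace(cfg, **dict(zip(keys, combination)))
    cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg_copy.hook_point_layer)
    configs.append(cfg_copy)

print(len(configs))  # 2 l1 coefficients x 2 Lp norms -> 4 configs, one SAE each
```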