From ade29762b4b94e02c3b44a66bedf150e721aec7c Mon Sep 17 00:00:00 2001
From: jbloom-md
Date: Wed, 7 Feb 2024 18:42:35 -0800
Subject: [PATCH 1/9] get units tests working

---
 requirements.txt                             |   1 +
 sae_training/config.py                       |  17 +-
 sae_training/evals.py                        | 157 ++++++++++
 sae_training/lm_runner.py                    |   3 -
 sae_training/sparse_autoencoder.py           | 242 +---------------
 sae_training/timeit.py                       |  25 --
 sae_training/tmp.py                          |   0
 sae_training/toy_model_runner.py             |  10 +-
 sae_training/toy_models.py                   |  17 +-
 sae_training/train_sae_on_language_model.py  | 271 +-----------------
 sae_training/train_sae_on_toy_model.py       |  84 ++----
 .../test_language_model_sae_runner.py        | 184 ++----------
 tests/benchmark/test_toy_model_sae_runner.py |  23 +-
 tests/unit/test_sparse_autoencoder.py        |   5 +-
 14 files changed, 253 insertions(+), 786 deletions(-)
 create mode 100644 sae_training/evals.py
 delete mode 100644 sae_training/timeit.py
 delete mode 100644 sae_training/tmp.py

diff --git a/requirements.txt b/requirements.txt
index b33d2517..7737a85c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,5 @@ pylint==3.0.2
 black==23.11.0
 pytest==7.4.3
 pytest-cov==4.1.0
+pre-commit==3.6.0
 git+https://github.com/callummcdougall/eindex.git
\ No newline at end of file
diff --git a/sae_training/config.py b/sae_training/config.py
index 8fde2c9b..4ceb0a1b 100644
--- a/sae_training/config.py
+++ b/sae_training/config.py
@@ -61,18 +61,15 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
     # Training Parameters
     l1_coefficient: float = 1e-3
     lr: float = 3e-4
-    lr_scheduler_name: str = "constant" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
+    lr_scheduler_name: str = "constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
     lr_warm_up_steps: int = 500
     train_batch_size: int = 4096
 
     # Resampling protocol args
     use_ghost_grads: bool = False # want to change this to true on some timeline.
     feature_sampling_window: int = 2000
-    feature_sampling_method: str = "Anthropic" # None or Anthropic
-    resample_batches: int = 32
-    feature_reinit_scale: float = 0.2
     dead_feature_window: int = 1000 # unless this window is larger feature sampling,
-    dead_feature_estimation_method: str = "no_fire"
+
     dead_feature_threshold: float = 1e-8
 
     # WANDB
@@ -94,11 +91,6 @@ def __post_init__(self):
         self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
 
-        if self.feature_sampling_method not in [None, "l2", "anthropic"]:
-            raise ValueError(
-                f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}"
-            )
-
         if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
             raise ValueError(
                 f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
             )
@@ -134,7 +126,6 @@ def __post_init__(self):
         # how many times will we sample dead neurons?
         # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
-        n_dead_feature_samples = total_training_steps // self.dead_feature_window
         n_feature_window_samples = total_training_steps // self.feature_sampling_window
         print(
             f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}"
         )
         print(
             f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
         )
-        if self.feature_sampling_method != None:
-            print(f"We will reset neurons {n_dead_feature_samples} times.")
+
         if self.use_ghost_grads:
             print("Using Ghost Grads.")
         print(
             f"We will reset the sparsity calculation {n_feature_window_samples} times."
         )
-        print(f"Number of tokens when resampling: {self.resample_batches * self.store_batch_size}")
         # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
         print(
             f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}"
         )
diff --git a/sae_training/evals.py b/sae_training/evals.py
new file mode 100644
index 00000000..2d22138f
--- /dev/null
+++ b/sae_training/evals.py
@@ -0,0 +1,157 @@
+from functools import partial
+
+import torch
+from transformer_lens import HookedTransformer
+from transformer_lens.utils import get_act_name
+
+import wandb
+from sae_training.activations_store import ActivationsStore
+from sae_training.sparse_autoencoder import SparseAutoencoder
+
+
+@torch.no_grad()
+def run_evals(sparse_autoencoder: SparseAutoencoder, activation_store: ActivationsStore, model: HookedTransformer, n_training_steps: int):
+
+    hook_point = sparse_autoencoder.cfg.hook_point
+    hook_point_layer = sparse_autoencoder.cfg.hook_point_layer
+    hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index
+
+    ### Evals
+    eval_tokens = activation_store.get_batch_tokens()
+
+    # Get Reconstruction Score
+    recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(sparse_autoencoder, model, activation_store, eval_tokens)
+
+    # get cache
+    _, cache = model.run_with_cache(eval_tokens, prepend_bos=False, names_filter=[get_act_name("pattern", hook_point_layer), hook_point])
+
+    # get act
+    if sparse_autoencoder.cfg.hook_point_head_index is not None:
+        original_act = cache[sparse_autoencoder.cfg.hook_point][:,:,sparse_autoencoder.cfg.hook_point_head_index]
+    else:
+        original_act = cache[sparse_autoencoder.cfg.hook_point]
+
+    sae_out, feature_acts, _, _, _, _ = sparse_autoencoder(
+        original_act
+    )
+    patterns_original = cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu()
+    del cache
+
+    if "cuda" in str(model.cfg.device):
+        torch.cuda.empty_cache()
+
+    l2_norm_in = torch.norm(original_act, dim=-1)
+    l2_norm_out = torch.norm(sae_out, dim=-1)
+    l2_norm_ratio = l2_norm_out / l2_norm_in
+
+    wandb.log(
+        {
+
+            # l2 norms
+            "metrics/l2_norm": l2_norm_out.mean().item(),
+            "metrics/l2_ratio": l2_norm_ratio.mean().item(),
+
+            # CE Loss
+            "metrics/CE_loss_score": recons_score,
+            "metrics/ce_loss_without_sae": ntp_loss,
+            "metrics/ce_loss_with_sae": recons_loss,
+            "metrics/ce_loss_with_ablation": zero_abl_loss,
+
+        },
+        step=n_training_steps,
+    )
+
+    head_index = sparse_autoencoder.cfg.hook_point_head_index
+
+    def standard_replacement_hook(activations, hook):
+        activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
+        return activations
+
+    def head_replacement_hook(activations, hook):
+        new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype)
+        activations[:,:,head_index] = new_actions
+        return activations
+
+    head_index = sparse_autoencoder.cfg.hook_point_head_index
+    replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook
+
+    # get attn when using reconstructed activations
+    with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]):
+        _, new_cache = model.run_with_cache(eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)])
+        patterns_reconstructed = new_cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu()
+        del new_cache
+
+    # get attn when using reconstructed activations
+    with model.hooks(fwd_hooks=[(hook_point, partial(zero_ablate_hook))]):
+        _, zero_ablation_cache = model.run_with_cache(eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)])
+        patterns_ablation = zero_ablation_cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu()
+        del zero_ablation_cache
+
+    if sparse_autoencoder.cfg.hook_point_head_index:
+
+        kl_result_reconstructed = kl_divergence_attention(patterns_original, patterns_reconstructed)
+        kl_result_reconstructed = kl_result_reconstructed.sum(dim=-1).numpy()
+
+
+        kl_result_ablation = kl_divergence_attention(patterns_original, patterns_ablation)
+        kl_result_ablation = kl_result_ablation.sum(dim=-1).numpy()
+
+        wandb.log(
+            {
+
+                "metrics/kldiv_reconstructed": kl_result_reconstructed.mean().item(),
+                "metrics/kldiv_ablation": kl_result_ablation.mean().item(),
+
+            },
+            step=n_training_steps,
+        )
+
+@torch.no_grad()
+def get_recons_loss(sparse_autoencoder, model, activation_store, batch_tokens):
+    hook_point = activation_store.cfg.hook_point
+    loss = model(batch_tokens, return_type="loss")
+
+    head_index = sparse_autoencoder.cfg.hook_point_head_index
+
+    def standard_replacement_hook(activations, hook):
+        activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
+        return activations
+
+    def head_replacement_hook(activations, hook):
+        new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype)
+        activations[:,:,head_index] = new_actions
+        return activations
+
+    replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook
+    recons_loss = model.run_with_hooks(
+        batch_tokens,
+        return_type="loss",
+        fwd_hooks=[(hook_point, partial(replacement_hook))],
+    )
+
+    zero_abl_loss = model.run_with_hooks(
+        batch_tokens, return_type="loss", fwd_hooks=[(hook_point, zero_ablate_hook)]
+    )
+
+    score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)
+
+    return score, loss, recons_loss, zero_abl_loss
+
+
+def mean_ablate_hook(mlp_post, hook):
+    mlp_post[:] = mlp_post.mean([0, 1]).to(mlp_post.dtype)
+    return mlp_post
+
+
+def zero_ablate_hook(mlp_post, hook):
+    mlp_post[:] = 0.0
+    return mlp_post
+
+
+def kl_divergence_attention(y_true, y_pred):
+
+    # Compute log probabilities for KL divergence
+    log_y_true = torch.log2(y_true + 1e-10)
+    log_y_pred = torch.log2(y_pred + 1e-10)
+
+    return y_true * (log_y_true - log_y_pred)
\ No newline at end of file
diff --git a/sae_training/lm_runner.py b/sae_training/lm_runner.py
index 06f92567..d1609bb6 100644
--- a/sae_training/lm_runner.py
+++ b/sae_training/lm_runner.py
@@ -30,11 +30,8 @@ def 
language_model_sae_runner(cfg): model, sparse_autoencoder, activations_loader, n_checkpoints=cfg.n_checkpoints, batch_size = cfg.train_batch_size, - feature_sampling_method = cfg.feature_sampling_method, feature_sampling_window = cfg.feature_sampling_window, - feature_reinit_scale = cfg.feature_reinit_scale, dead_feature_threshold = cfg.dead_feature_threshold, - dead_feature_window=cfg.dead_feature_window, use_wandb = cfg.log_to_wandb, wandb_log_frequency = cfg.wandb_log_frequency ) diff --git a/sae_training/sparse_autoencoder.py b/sae_training/sparse_autoencoder.py index ded561ae..967ffe2a 100644 --- a/sae_training/sparse_autoencoder.py +++ b/sae_training/sparse_autoencoder.py @@ -76,7 +76,7 @@ def forward(self, x, dead_neuron_mask = None): x = x.to(self.dtype) sae_in = self.hook_sae_in( x - self.b_dec - ) # Remove encoder bias as per Anthropic + ) # Remove decoder bias as per Anthropic hidden_pre = self.hook_hidden_pre( einops.einsum( @@ -98,7 +98,8 @@ def forward(self, x, dead_neuron_mask = None): ) # add config for whether l2 is normalized: - mse_loss = (torch.pow((sae_out-x.float()), 2) / (x**2).sum(dim=-1, keepdim=True).sqrt()) + x_centred = x - x.mean(dim=0, keepdim=True) + mse_loss = (torch.pow((sae_out-x.float()), 2) / (x_centred**2).sum(dim=-1, keepdim=True).sqrt()) mse_loss_ghost_resid = torch.tensor(0.0, dtype=self.dtype, device=self.device) # gate on config and training so evals is not slowed down. @@ -125,7 +126,9 @@ def forward(self, x, dead_neuron_mask = None): mse_rescaling_factor = (mse_loss / (mse_loss_ghost_resid + 1e-6)).detach() mse_loss_ghost_resid = mse_rescaling_factor * mse_loss_ghost_resid - mse_loss_ghost_resid = mse_loss_ghost_resid.mean() + mse_loss_ghost_resid = mse_loss_ghost_resid.mean() + + mse_loss = mse_loss.mean() sparsity = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,)) l1_loss = self.l1_coefficient * sparsity @@ -155,7 +158,6 @@ def initialize_b_dec_with_geometric_median(self, activation_store): skip_typechecks=True, maxiter=100, per_component=False).median - previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1) distances = torch.norm(all_activations - out, dim=-1) @@ -181,238 +183,7 @@ def initialize_b_dec_with_mean(self, activation_store): print(f"New distances: {distances.median(0).values.mean().item()}") self.b_dec.data = out.to(self.dtype).to(self.device) - - - @torch.no_grad() - def resample_neurons_l2( - self, - x: Float[Tensor, "batch_size n_hidden"], - feature_sparsity: Float[Tensor, "n_hidden_ae"], - optimizer: torch.optim.Optimizer, - ) -> None: - ''' - Resamples neurons that have been dead for `dead_neuron_window` steps, according to `frac_active`. - - I'll probably break this now and fix it later! - ''' - - feature_reinit_scale = self.cfg.feature_reinit_scale - - sae_out, _, _, _, _ = self.forward(x) - per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1).squeeze() - - # Find the dead neurons in this instance. If all neurons are alive, continue - is_dead = (feature_sparsity < self.cfg.dead_feature_threshold) - dead_neurons = torch.nonzero(is_dead).squeeze(-1) - alive_neurons = torch.nonzero(~is_dead).squeeze(-1) - n_dead = dead_neurons.numel() - - if n_dead == 0: - return 0 # If there are no dead neurons, we don't need to resample neurons - - # Compute L2 loss for each element in the batch - # TODO: Check whether we need to go through more batches as features get sparse to find high l2 loss examples. 
- if per_token_l2_loss.max() < 1e-6: - return 0 # If we have zero reconstruction loss, we don't need to resample neurons - - # Draw `n_hidden_ae` samples from [0, 1, ..., batch_size-1], with probabilities proportional to l2_loss squared - per_token_l2_loss = per_token_l2_loss.to(torch.float32) # wont' work with bfloat16 - distn = Categorical(probs = per_token_l2_loss.pow(2) / (per_token_l2_loss.pow(2).sum())) - replacement_indices = distn.sample((n_dead,)) # shape [n_dead] - - # Index into the batch of hidden activations to get our replacement values - replacement_values = (x - self.b_dec)[replacement_indices] # shape [n_dead n_input_ae] - - # unit norm - replacement_values = (replacement_values / (replacement_values.norm(dim=1, keepdim=True) + 1e-8)) - - # St new decoder weights - self.W_dec.data[is_dead, :] = replacement_values - - # Get the norm of alive neurons (or 1.0 if there are no alive neurons) - W_enc_norm_alive_mean = 1.0 if len(alive_neurons) == 0 else self.W_enc[:, alive_neurons].norm(dim=0).mean().item() - - # Lastly, set the new weights & biases - self.W_enc.data[:, is_dead] = (replacement_values * W_enc_norm_alive_mean * feature_reinit_scale).T - self.b_enc.data[is_dead] = 0.0 - - - # reset the Adam Optimiser for every modified weight and bias term - # Reset all the Adam parameters - for dict_idx, (k, v) in enumerate(optimizer.state.items()): - for v_key in ["exp_avg", "exp_avg_sq"]: - if dict_idx == 0: - assert k.data.shape == (self.d_in, self.d_sae) - v[v_key][:, is_dead] = 0.0 - elif dict_idx == 1: - assert k.data.shape == (self.d_sae,) - v[v_key][is_dead] = 0.0 - elif dict_idx == 2: - assert k.data.shape == (self.d_sae, self.d_in) - v[v_key][is_dead, :] = 0.0 - elif dict_idx == 3: - assert k.data.shape == (self.d_in,) - else: - raise ValueError(f"Unexpected dict_idx {dict_idx}") - - # Check that the opt is really updated - for dict_idx, (k, v) in enumerate(optimizer.state.items()): - for v_key in ["exp_avg", "exp_avg_sq"]: - if dict_idx == 0: - if k.data.shape != (self.d_in, self.d_sae): - print( - "Warning: it does not seem as if resetting the Adam parameters worked, there are shapes mismatches" - ) - if v[v_key][:, is_dead].abs().max().item() > 1e-6: - print( - "Warning: it does not seem as if resetting the Adam parameters worked" - ) - - return n_dead - - @torch.no_grad() - def resample_neurons_anthropic( - self, - dead_neuron_indices, - model, - optimizer, - activation_store): - """ - Arthur's version of Anthropic's feature resampling - procedure. 
- """ - # collect global loss increases, and input activations - global_loss_increases, global_input_activations = self.collect_anthropic_resampling_losses( - model, activation_store - ) - # sample according to losses - probs = global_loss_increases / global_loss_increases.sum() - sample_indices = torch.multinomial( - probs, - min(len(dead_neuron_indices), probs.shape[0]), - replacement=False, - ) - # if we don't have enough samples for for all the dead neurons, take the first n - if sample_indices.shape[0] < len(dead_neuron_indices): - dead_neuron_indices = dead_neuron_indices[:sample_indices.shape[0]] - - # Replace W_dec with normalized differences in activations - self.W_dec.data[dead_neuron_indices, :] = ( - ( - global_input_activations[sample_indices] - / torch.norm(global_input_activations[sample_indices], dim=1, keepdim=True) - ) - .to(self.dtype) - .to(self.device) - ) - - # Lastly, set the new weights & biases - self.W_enc.data[:, dead_neuron_indices] = self.W_dec.data[dead_neuron_indices, :].T - self.b_enc.data[dead_neuron_indices] = 0.0 - - # Reset the Encoder Weights - if dead_neuron_indices.shape[0] < self.d_sae: - sum_of_all_norms = torch.norm(self.W_enc.data, dim=0).sum() - sum_of_all_norms -= len(dead_neuron_indices) - average_norm = sum_of_all_norms / (self.d_sae - len(dead_neuron_indices)) - self.W_enc.data[:, dead_neuron_indices] *= self.cfg.feature_reinit_scale * average_norm - - # Set biases to resampled value - relevant_biases = self.b_enc.data[dead_neuron_indices].mean() - self.b_enc.data[dead_neuron_indices] = relevant_biases * 0 # bias resample factor (put in config?) - - else: - self.W_enc.data[:, dead_neuron_indices] *= self.cfg.feature_reinit_scale - self.b_enc.data[dead_neuron_indices] = -5.0 - - # TODO: Refactor this resetting to be outside of resampling. 
- # reset the Adam Optimiser for every modified weight and bias term - # Reset all the Adam parameters - for dict_idx, (k, v) in enumerate(optimizer.state.items()): - for v_key in ["exp_avg", "exp_avg_sq"]: - if dict_idx == 0: - assert k.data.shape == (self.d_in, self.d_sae) - v[v_key][:, dead_neuron_indices] = 0.0 - elif dict_idx == 1: - assert k.data.shape == (self.d_sae,) - v[v_key][dead_neuron_indices] = 0.0 - elif dict_idx == 2: - assert k.data.shape == (self.d_sae, self.d_in) - v[v_key][dead_neuron_indices, :] = 0.0 - elif dict_idx == 3: - assert k.data.shape == (self.d_in,) - else: - raise ValueError(f"Unexpected dict_idx {dict_idx}") - - # Check that the opt is really updated - for dict_idx, (k, v) in enumerate(optimizer.state.items()): - for v_key in ["exp_avg", "exp_avg_sq"]: - if dict_idx == 0: - if k.data.shape != (self.d_in, self.d_sae): - print( - "Warning: it does not seem as if resetting the Adam parameters worked, there are shapes mismatches" - ) - if v[v_key][:, dead_neuron_indices].abs().max().item() > 1e-6: - print( - "Warning: it does not seem as if resetting the Adam parameters worked" - ) - - return - - @torch.no_grad() - def collect_anthropic_resampling_losses(self, model, activation_store): - """ - Collects the losses for resampling neurons (anthropic) - """ - - batch_size = self.cfg.store_batch_size - - # we're going to collect this many forward passes - number_final_activations = self.cfg.resample_batches * batch_size - # but have seq len number of tokens in each - number_activations_total = number_final_activations * self.cfg.context_size - anthropic_iterator = range(0, number_final_activations, batch_size) - anthropic_iterator = tqdm(anthropic_iterator, desc="Collecting losses for resampling...") - - global_loss_increases = torch.zeros((number_final_activations,), dtype=self.dtype, device=self.device) - global_input_activations = torch.zeros((number_final_activations, self.d_in), dtype=self.dtype, device=self.device) - - for refill_idx in anthropic_iterator: - - # get a batch, calculate loss with/without using SAE reconstruction. 
- batch_tokens = activation_store.get_batch_tokens() - ce_loss_with_recons = self.get_test_loss(batch_tokens, model) - ce_loss_without_recons, normal_activations_cache = model.run_with_cache( - batch_tokens, - names_filter=self.cfg.hook_point, - return_type = "loss", - loss_per_token = True, - ) - # ce_loss_without_recons = model.loss_fn(normal_logits, batch_tokens, True) - # del normal_logits - - normal_activations = normal_activations_cache[self.cfg.hook_point] - if self.cfg.hook_point_head_index is not None: - normal_activations = normal_activations[:,:,self.cfg.hook_point_head_index] - - # calculate the difference in loss - changes_in_loss = ce_loss_with_recons - ce_loss_without_recons - changes_in_loss = changes_in_loss.cpu() - - # sample from the loss differences - probs = F.relu(changes_in_loss) / F.relu(changes_in_loss).sum(dim=1, keepdim=True) - changes_in_loss_dist = Categorical(probs) - samples = changes_in_loss_dist.sample() - - assert samples.shape == (batch_size,), f"{samples.shape=}; {self.cfg.store_batch_size=}" - - end_idx = refill_idx + batch_size - global_loss_increases[refill_idx:end_idx] = changes_in_loss[torch.arange(batch_size), samples] - global_input_activations[refill_idx:end_idx] = normal_activations[torch.arange(batch_size), samples] - - return global_loss_increases, global_input_activations - @torch.no_grad() def get_test_loss(self, batch_tokens, model): """ @@ -439,7 +210,6 @@ def head_replacement_hook(activations, hook): ) return ce_loss_with_recons - @torch.no_grad() def set_decoder_norm_to_unit_norm(self): diff --git a/sae_training/timeit.py b/sae_training/timeit.py deleted file mode 100644 index d09a6db1..00000000 --- a/sae_training/timeit.py +++ /dev/null @@ -1,25 +0,0 @@ -""" - This is a util to time the execution of a function. - - (Has to be a separate file, if you put it in utils.py you get circular imports; need to find a permanent home for it) -""" - -from functools import wraps -import time - -def timeit(func): - """ - Decorator to time a function. - - Taken from https://dev.to/kcdchennai/python-decorator-to-measure-execution-time-54hk - """ - @wraps(func) - def timeit_wrapper(*args, **kwargs): - start_time = time.perf_counter() - result = func(*args, **kwargs) - end_time = time.perf_counter() - total_time = end_time - start_time - print(f'Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds') - return result - return timeit_wrapper - \ No newline at end of file diff --git a/sae_training/tmp.py b/sae_training/tmp.py deleted file mode 100644 index e69de29b..00000000 diff --git a/sae_training/toy_model_runner.py b/sae_training/toy_model_runner.py index 87a3c951..f28fd368 100644 --- a/sae_training/toy_model_runner.py +++ b/sae_training/toy_model_runner.py @@ -29,11 +29,11 @@ class SAEToyModelRunnerConfig: l1_coefficient: float = 1e-3 lr: float = 3e-4 train_batch_size: int = 1024 + b_dec_init_method: str = "geometric_median" - # Resampling protocol args - feature_sampling_method: str = "l2" # None or l2 + # Sparsity / Dead Feature Handling + use_ghost_grads: bool = False # not currently implemented, but SAE class expects it. 
feature_sampling_window: int = 100 - feature_reinit_scale: float = 0.2 dead_feature_window: int = 100 # unless this window is larger feature sampling, dead_feature_threshold: float = 1e-8 @@ -92,12 +92,10 @@ def toy_model_sae_runner(cfg): wandb.init(project=cfg.wandb_project, config=cfg) sparse_autoencoder = train_toy_sae( - model, # need model so we can do evals for neuron resampling sparse_autoencoder, - hidden.detach().squeeze(), + activation_store=hidden.detach().squeeze(), batch_size=cfg.train_batch_size, feature_sampling_window=cfg.feature_sampling_window, - feature_reinit_scale=cfg.feature_reinit_scale, dead_feature_threshold=cfg.dead_feature_threshold, use_wandb=cfg.log_to_wandb, wandb_log_frequency=cfg.wandb_log_frequency, diff --git a/sae_training/toy_models.py b/sae_training/toy_models.py index 1a6b5aeb..ed540ad8 100644 --- a/sae_training/toy_models.py +++ b/sae_training/toy_models.py @@ -149,9 +149,6 @@ def generate_batch(self, batch_size) -> Float[Tensor, "batch_size instances feat batch = t.cat(data, dim=-1) return batch - - - def calculate_loss( self, out: Float[Tensor, "batch instances features"], @@ -344,15 +341,15 @@ def parse_colors_for_superposition_plot( This function unifies them all by turning colors into a list of lists of strings, i.e. one color for each instance & feature. ''' # If colors is a tensor, we assume it's the importances tensor, and we color according to a viridis color scheme - if isinstance(colors, Tensor): - colors = t.broadcast_to(colors, (n_instances, n_feats)) - colors = [ - [helper_get_viridis(v.item()) for v in colors_for_this_instance] - for colors_for_this_instance in colors - ] + # if isinstance(colors, Tensor): + # colors = t.broadcast_to(colors, (n_instances, n_feats)) + # colors = [ + # [helper_get_viridis(v.item()) for v in colors_for_this_instance] + # for colors_for_this_instance in colors + # ] # If colors is a tuple of ints, it's interpreted as number of correlated / anticorrelated pairs - elif isinstance(colors, tuple): + if isinstance(colors, tuple): n_corr, n_anti = colors n_indep = n_feats - 2 * (n_corr - n_anti) colors = [ diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py index 81359da6..ecf3c702 100644 --- a/sae_training/train_sae_on_language_model.py +++ b/sae_training/train_sae_on_language_model.py @@ -1,15 +1,13 @@ -from functools import partial -import numpy as np -import plotly_express as px + import torch from torch.optim import Adam from tqdm import tqdm from transformer_lens import HookedTransformer -from transformer_lens.utils import get_act_name import wandb from sae_training.activations_store import ActivationsStore +from sae_training.evals import run_evals from sae_training.optim import get_scheduler from sae_training.sparse_autoencoder import SparseAutoencoder @@ -20,24 +18,17 @@ def train_sae_on_language_model( activation_store: ActivationsStore, batch_size: int = 1024, n_checkpoints: int = 0, - feature_sampling_method: str = "l2", # None, l2, or anthropic feature_sampling_window: int = 1000, # how many training steps between resampling the features / considiring neurons dead - feature_reinit_scale: float = 0.2, # how much to scale the resampled features by dead_feature_threshold: float = 1e-8, # how infrequently a feature has to be active to be considered dead - dead_feature_window: int = 2000, # how many training steps before a feature is considered dead use_wandb: bool = False, wandb_log_frequency: int = 50, ): - if feature_sampling_method is not None: - 
feature_sampling_method = feature_sampling_method.lower() - total_training_tokens = sparse_autoencoder.cfg.total_training_tokens total_training_steps = total_training_tokens // batch_size n_training_steps = 0 n_training_tokens = 0 - n_resampled_neurons = 0 - steps_before_reset = 0 + if n_checkpoints > 0: checkpoint_thresholds = list(range(0, total_training_tokens, total_training_tokens // n_checkpoints))[1:] @@ -66,51 +57,6 @@ def train_sae_on_language_model( # Make sure the W_dec is still zero-norm sparse_autoencoder.set_decoder_norm_to_unit_norm() - - if (feature_sampling_method=="anthropic") and ((n_training_steps + 1) % dead_feature_window == 0): - - feature_sparsity = act_freq_scores / n_frac_active_tokens - - # if reset criterion is frequency in window, then then use that to generate indices. - if sparse_autoencoder.cfg.dead_feature_estimation_method == "no_fire": - dead_neuron_indices = (act_freq_scores == 0).nonzero(as_tuple=False)[:, 0] - elif sparse_autoencoder.cfg.dead_feature_estimation_method == "frequency": - dead_neuron_indices = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0] - - if len(dead_neuron_indices) > 0: - - if len(dead_neuron_indices) > sparse_autoencoder.cfg.resample_batches * sparse_autoencoder.cfg.store_batch_size: - print("Warning: more dead neurons than number of tokens. Consider sampling more tokens when resampling.") - - sparse_autoencoder.resample_neurons_anthropic( - dead_neuron_indices, - model, - optimizer, - activation_store - ) - - if use_wandb: - n_resampled_neurons = min(len(dead_neuron_indices), sparse_autoencoder.cfg.store_batch_size * sparse_autoencoder.cfg.resample_batches) - wandb.log( - { - "metrics/n_resampled_neurons": n_resampled_neurons, - }, - step=n_training_steps, - ) - - # for now, we'll hardcode this. - current_lr = scheduler.get_last_lr()[0] - reduced_lr = current_lr / 10_000 - increment = (current_lr - reduced_lr) / 10_000 - optimizer.param_groups[0]['lr'] = reduced_lr - steps_before_reset = 10_000 - else: - print("No dead neurons, skipping resampling") - - # Resample dead neurons - if (feature_sampling_method == "l2") and ((n_training_steps + 1) % dead_feature_window == 0): - print("no l2 resampling currently. 
Please use anthropic resampling") - # after resampling, reset the sparsity: if (n_training_steps + 1) % feature_sampling_window == 0: @@ -130,16 +76,11 @@ def train_sae_on_language_model( act_freq_scores = torch.zeros(sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device) n_frac_active_tokens = 0 - - - if (steps_before_reset > 0) and n_training_steps > 0: - steps_before_reset -= 1 - optimizer.param_groups[0]['lr'] += increment - if steps_before_reset == 0: - optimizer.param_groups[0]['lr'] = current_lr - else: + scheduler.step() scheduler.step() + scheduler.step() + optimizer.zero_grad() ghost_grad_neuron_mask = (n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window).bool() @@ -168,7 +109,7 @@ def train_sae_on_language_model( current_learning_rate = optimizer.param_groups[0]["lr"] per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze() - total_variance = sae_in.pow(2).sum(-1) + total_variance = (sae_in-sae_in.mean(0)).pow(2).sum(-1) explained_variance = 1 - per_token_l2_loss/total_variance wandb.log( @@ -227,6 +168,7 @@ def train_sae_on_language_model( path = f"{sparse_autoencoder.cfg.checkpoint_path}/{n_training_tokens}_{sparse_autoencoder.get_name()}.pt" log_feature_sparsity_path = f"{sparse_autoencoder.cfg.checkpoint_path}/{n_training_tokens}_{sparse_autoencoder.get_name()}_log_feature_sparsity.pt" sparse_autoencoder.save_model(path) + log_feature_sparsity = torch.log10(feature_sparsity + 1e-10).detach().cpu() torch.save(log_feature_sparsity, log_feature_sparsity_path) checkpoint_thresholds.pop(0) if len(checkpoint_thresholds) == 0: @@ -259,200 +201,3 @@ def train_sae_on_language_model( return sparse_autoencoder - - -@torch.no_grad() -def run_evals(sparse_autoencoder: SparseAutoencoder, activation_store: ActivationsStore, model: HookedTransformer, n_training_steps: int): - - hook_point = sparse_autoencoder.cfg.hook_point - hook_point_layer = sparse_autoencoder.cfg.hook_point_layer - hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index - - ### Evals - eval_tokens = activation_store.get_batch_tokens() - - # Get Reconstruction Score - recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(sparse_autoencoder, model, activation_store, eval_tokens) - - # get cache - _, cache = model.run_with_cache(eval_tokens, prepend_bos=False, names_filter=[get_act_name("pattern", hook_point_layer), hook_point]) - - # get act - if sparse_autoencoder.cfg.hook_point_head_index is not None: - original_act = cache[sparse_autoencoder.cfg.hook_point][:,:,sparse_autoencoder.cfg.hook_point_head_index] - else: - original_act = cache[sparse_autoencoder.cfg.hook_point] - - sae_out, feature_acts, _, _, _, _ = sparse_autoencoder( - original_act - ) - patterns_original = cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu() - del cache - - if "cuda" in str(model.cfg.device): - torch.cuda.empty_cache() - - l2_norm_in = torch.norm(original_act, dim=-1) - l2_norm_out = torch.norm(sae_out, dim=-1) - l2_norm_ratio = l2_norm_out / l2_norm_in - - wandb.log( - { - - # l2 norms - "metrics/l2_norm": l2_norm_out.mean().item(), - "metrics/l2_ratio": l2_norm_ratio.mean().item(), - - # CE Loss - "metrics/CE_loss_score": recons_score, - "metrics/ce_loss_without_sae": ntp_loss, - "metrics/ce_loss_with_sae": recons_loss, - "metrics/ce_loss_with_ablation": zero_abl_loss, - - }, - step=n_training_steps, - ) - - head_index = sparse_autoencoder.cfg.hook_point_head_index - - def standard_replacement_hook(activations, hook): - 
activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype) - return activations - - def head_replacement_hook(activations, hook): - new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype) - activations[:,:,head_index] = new_actions - return activations - - head_index = sparse_autoencoder.cfg.hook_point_head_index - replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook - - # get attn when using reconstructed activations - with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]): - _, new_cache = model.run_with_cache(eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]) - patterns_reconstructed = new_cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu() - del new_cache - - # get attn when using reconstructed activations - with model.hooks(fwd_hooks=[(hook_point, partial(zero_ablate_hook))]): - _, zero_ablation_cache = model.run_with_cache(eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]) - patterns_ablation = zero_ablation_cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu() - del zero_ablation_cache - - - # Visualizations to show L0 / MSE distributions - # l0 = (feature_acts > 0).float().sum(-1) - # per_token_l2_loss = (sae_out - original_act).pow(2).sum(dim=-1).squeeze() - - # fig = px.scatter( - # x = per_token_l2_loss.flatten().cpu().numpy(), - # y = l0.flatten().cpu().numpy(), - # color = np.arange(per_token_l2_loss.shape[1]).repeat(per_token_l2_loss.shape[0]), - # opacity=0.5, - # labels = {"color": "position", "x": "MSE Loss", "y": "L0"}, - # title = "L0 vs MSE Loss", - # marginal_x="histogram", - # marginal_y="histogram", - # ) - # wandb.log({"plots/l0_vs_mse_loss": wandb.Plotly(fig)}, step = n_training_steps) - - # fig = px.scatter( - # x = per_token_l2_loss.flatten().cpu().numpy(), - # y = l2_norm_in.flatten().cpu().numpy(), - # color = np.arange(per_token_l2_loss.shape[1]).repeat(per_token_l2_loss.shape[0]), - # opacity=0.5, - # labels={"color": "position", "x": "MSE Loss", "y": "L2 Norm"}, - # title = "L2 Norm vs MSE Loss", - # marginal_x="histogram", - # marginal_y="histogram", - # ) - # wandb.log({"plots/l2_norm_vs_mse_loss": wandb.Plotly(fig)}, step = n_training_steps) - - # if dealing with a head SAE, do the head metrics. 
- if sparse_autoencoder.cfg.hook_point_head_index: - - # show patterns before/after - # fig_patterns_original = px.imshow(patterns_original[0].numpy(), title="original attn scores", - # color_continuous_midpoint=0, color_continuous_scale="RdBu") - # fig_patterns_original.update_layout(coloraxis_showscale=False) # hide colorbar - # wandb.log({"attention/patterns_original": wandb.Plotly(fig_patterns_original)}, step = n_training_steps) - # fig_patterns_reconstructed = px.imshow(patterns_reconstructed[0].numpy(), title="reconstructed attn scores", - # color_continuous_midpoint=0, color_continuous_scale="RdBu") - # fig_patterns_reconstructed.update_layout(coloraxis_showscale=False) # hide colorbar - # wandb.log({"attention/patterns_reconstructed": wandb.Plotly(fig_patterns_reconstructed)}, step = n_training_steps) - - kl_result_reconstructed = kl_divergence_attention(patterns_original, patterns_reconstructed) - kl_result_reconstructed = kl_result_reconstructed.sum(dim=-1).numpy() - # print(kl_result.mean().item()) - # px.imshow(kl_result, title="KL Divergence", width=800, height=800, - # color_continuous_midpoint=0, color_continuous_scale="RdBu").show() - # px.histogram(kl_result.flatten()).show() - # px.line(kl_result.mean(0), title="KL Divergence by Position").show() - - kl_result_ablation = kl_divergence_attention(patterns_original, patterns_ablation) - kl_result_ablation = kl_result_ablation.sum(dim=-1).numpy() - # print(kl_result.mean().item()) - # # px.imshow(kl_result, title="KL Divergence", width=800, height=800, - # # color_continuous_midpoint=0, color_continuous_scale="RdBu").show() - # px.histogram(kl_result.flatten()).show() - # px.line(kl_result.mean(0), title="KL Divergence by Position").show() - - wandb.log( - { - - "metrics/kldiv_reconstructed": kl_result_reconstructed.mean().item(), - "metrics/kldiv_ablation": kl_result_ablation.mean().item(), - - }, - step=n_training_steps, - ) - -@torch.no_grad() -def get_recons_loss(sparse_autoencoder, model, activation_store, batch_tokens): - hook_point = activation_store.cfg.hook_point - loss = model(batch_tokens, return_type="loss") - - head_index = sparse_autoencoder.cfg.hook_point_head_index - - def standard_replacement_hook(activations, hook): - activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype) - return activations - - def head_replacement_hook(activations, hook): - new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype) - activations[:,:,head_index] = new_actions - return activations - - replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook - recons_loss = model.run_with_hooks( - batch_tokens, - return_type="loss", - fwd_hooks=[(hook_point, partial(replacement_hook))], - ) - - zero_abl_loss = model.run_with_hooks( - batch_tokens, return_type="loss", fwd_hooks=[(hook_point, zero_ablate_hook)] - ) - - score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss) - - return score, loss, recons_loss, zero_abl_loss - - -def mean_ablate_hook(mlp_post, hook): - mlp_post[:] = mlp_post.mean([0, 1]).to(mlp_post.dtype) - return mlp_post - - -def zero_ablate_hook(mlp_post, hook): - mlp_post[:] = 0.0 - return mlp_post - - -def kl_divergence_attention(y_true, y_pred): - - # Compute log probabilities for KL divergence - log_y_true = torch.log2(y_true + 1e-10) - log_y_pred = torch.log2(y_pred + 1e-10) - - return y_true * (log_y_true - log_y_pred) \ No newline at end of file diff --git a/sae_training/train_sae_on_toy_model.py 
b/sae_training/train_sae_on_toy_model.py index 5bbd937b..7ad3c29e 100644 --- a/sae_training/train_sae_on_toy_model.py +++ b/sae_training/train_sae_on_toy_model.py @@ -9,15 +9,11 @@ def train_toy_sae( - model: ToyModel, sparse_autoencoder: SparseAutoencoder, activation_store, batch_size: int = 1024, - total_training_tokens: int = 1024 * 10_000, - feature_sampling_method: str = "l2", # None, l2, or anthropic feature_sampling_window: int = 100, # how many training steps between resampling the features / considiring neurons dead dead_feature_window: int = 2000, # how many training steps before a feature is considered dead - feature_reinit_scale: float = 0.2, # how much to scale the resampled features by dead_feature_threshold: float = 1e-8, # how infrequently a feature has to be active to be considered dead use_wandb: bool = False, wandb_log_frequency: int = 50, @@ -33,7 +29,6 @@ def train_toy_sae( n_training_steps = 0 n_training_tokens = 0 - n_resampled_neurons = 0 pbar = tqdm(dataloader, desc="Training SAE") for _, batch in enumerate(pbar): @@ -42,38 +37,18 @@ def train_toy_sae( # Make sure the W_dec is still zero-norm sparse_autoencoder.set_decoder_norm_to_unit_norm() - # Resample dead neurons - if (feature_sampling_method is not None) and ((n_training_steps + 1) % dead_feature_window == 0): - - # Get the fraction of neurons active in the previous window - frac_active_in_window = torch.stack(frac_active_list[-dead_feature_window:], dim=0) - feature_sparsity = frac_active_in_window.sum(0) / ( - dead_feature_window * batch_size - ) - - # Compute batch of hidden activations which we'll use in resampling - resampling_batch = model.generate_batch(batch_size) - - # Our version of running the model - hidden = einops.einsum( - resampling_batch, - model.W, - "batch_size instances features, instances hidden features -> batch_size instances hidden", - ) - - # Resample - n_resampled_neurons = sparse_autoencoder.resample_neurons( - hidden, feature_sparsity, feature_reinit_scale - ) - - # Update learning rate here if using scheduler. 
- # Forward and Backward Passes optimizer.zero_grad() - _, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(batch) + sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(batch) + loss.backward() + sparse_autoencoder.remove_gradient_parallel_to_decoder_directions() + optimizer.step() + n_training_tokens += batch_size with torch.no_grad(): + + # Calculate the sparsities, and add it to a list act_freq_scores = (feature_acts.abs() > 0).float().sum(0) frac_active_list.append(act_freq_scores) @@ -93,57 +68,60 @@ def train_toy_sae( ) - l0 = (feature_acts > 0).float().sum(1).mean() + l0 = (feature_acts > 0).float().sum(-1).mean() + current_learning_rate = optimizer.param_groups[0]["lr"] l2_norm = torch.norm(feature_acts, dim=1).mean() + + l2_norm_in = torch.norm(batch, dim=-1) + l2_norm_out = torch.norm(sae_out, dim=-1) + l2_norm_ratio = l2_norm_out / (1e-6+l2_norm_in) + if use_wandb and ((n_training_steps + 1) % wandb_log_frequency == 0): wandb.log( { + "details/n_training_tokens": n_training_tokens, "losses/mse_loss": mse_loss.item(), "losses/l1_loss": l1_loss.item(), "losses/overall_loss": loss.item(), "metrics/l0": l0.item(), "metrics/l2": l2_norm.item(), - "metrics/below_1e-5": (feature_sparsity < 1e-5) + "metrics/l2_ratio": l2_norm_ratio.mean().item(), + "sparsity/below_1e-5": (feature_sparsity < 1e-5) .float() .mean() .item(), - "metrics/below_1e-6": (feature_sparsity < 1e-6) + "sparsity/below_1e-6": (feature_sparsity < 1e-6) .float() .mean() .item(), - "metrics/n_dead_features": ( + "sparsity/n_dead_features": ( feature_sparsity < dead_feature_threshold ) .float() .mean() .item(), - "metrics/n_resampled_neurons": n_resampled_neurons, - }, - step=n_training_steps, - - ) - - if (n_training_steps + 1) % (wandb_log_frequency * 100) == 0: - log_feature_sparsity = torch.log10(feature_sparsity + 1e-8) - wandb.log( - { - "plots/feature_density_histogram": wandb.Histogram( - log_feature_sparsity.tolist() - ), }, step=n_training_steps, ) + + if (n_training_steps + 1) % (wandb_log_frequency * 100) == 0: + log_feature_sparsity = torch.log10(feature_sparsity + 1e-8) + wandb.log( + { + "plots/feature_density_histogram": wandb.Histogram( + log_feature_sparsity.tolist() + ), + }, + step=n_training_steps, + ) pbar.set_description( f"{n_training_steps}| MSE Loss {mse_loss.item():.3f} | L0 {l0.item():.3f}" ) pbar.update(batch_size) - loss.backward() - sparse_autoencoder.remove_gradient_parallel_to_decoder_directions() - optimizer.step() - + # If we did checkpointing we'd do it here. n_training_steps += 1 diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index 4bf79dcd..28f064ee 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -4,8 +4,14 @@ from sae_training.lm_runner import language_model_sae_runner -def test_language_model_sae_runner_mlp_out(): +def test_language_model_sae_runner(): + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" cfg = LanguageModelSAERunnerConfig( @@ -18,199 +24,43 @@ def test_language_model_sae_runner_mlp_out(): is_dataset_tokenized=True, # SAE Parameters - expansion_factor = 64, + expansion_factor = 16, + b_dec_init_method="mean", # not ideal but quicker when testing code. 
# Training Parameters lr = 1e-4, - l1_coefficient = 3e-4, + l1_coefficient = 3e-4, train_batch_size = 4096, context_size = 128, # Activation Store Parameters n_batches_in_buffer = 24, - total_training_tokens = 5_000_00 * 100, + total_training_tokens = 1_000_000 * 10, store_batch_size = 32, # Resampling protocol - feature_sampling_method = 'l2', - feature_sampling_window = 2500, - feature_reinit_scale = 0.2, - dead_feature_window=1250, - dead_feature_threshold = 1e-8, - - # WANDB - log_to_wandb = True, - wandb_project= "mats_sae_training_language_models", - wandb_entity = None, - - # Misc - device = "mps", - seed = 42, - n_checkpoints = 5, - checkpoint_path = "checkpoints", - dtype = torch.float32, - ) - - sparse_autoencoder = language_model_sae_runner(cfg) - - assert sparse_autoencoder is not None - - - -def test_language_model_sae_runner_resid_pre(): - - cfg = LanguageModelSAERunnerConfig( - - # Data Generating Function (Model + Training Distibuion) - model_name = "gelu-2l", - hook_point = "blocks.0.hook_resid_mid", - hook_point_layer = 0, - d_in = 512, - dataset_path = "NeelNanda/c4-tokenized-2b", - is_dataset_tokenized=True, - - # SAE Parameters - expansion_factor = 64, + use_ghost_grads=True, - # Training Parameters - lr = 1e-4, - l1_coefficient = 1e-4, - train_batch_size = 4096, - context_size = 128, - - # Activation Store Parameters - n_batches_in_buffer = 24, - total_training_tokens = 5_000_00 * 100, - store_batch_size = 32, - - # Resampling protocol - feature_sampling_method = 'l2', - feature_sampling_window = 1000, - feature_reinit_scale = 0.2, + feature_sampling_window = 3000, # in steps + dead_feature_window=5000, dead_feature_threshold = 1e-8, # WANDB log_to_wandb = True, - wandb_project= "mats_sae_training_language_models", + wandb_project= "mats_sae_training_benchmarks", wandb_entity = None, # Misc - device = "cuda", + device = device, seed = 42, n_checkpoints = 5, checkpoint_path = "checkpoints", dtype = torch.float32, ) - trained_sae = language_model_sae_runner(cfg) - - assert trained_sae is not None - - - -def test_language_model_sae_runner_hook_q(): - - - # for l1_coefficient in [9e-4,8e-4,7e-4]: - cfg = LanguageModelSAERunnerConfig( - - # Data Generating Function (Model + Training Distibuion) - model_name = "gpt2-small", - hook_point = "blocks.10.attn.hook_q", - hook_point_layer = 10, - hook_point_head_index=7, - d_in = 64, - dataset_path = "Skylion007/openwebtext", - is_dataset_tokenized=False, - use_cached_activations=True, - cached_activations_path="../activations/", - - # SAE Parameters - expansion_factor = 64, # determines the dimension of the SAE. 
(64*64 = 4096, 64*4*64 = 32768) - - # Training Parameters - lr = 1e-3, - l1_coefficient = 2e-4, - # lr_scheduler_name="LinearWarmupDecay", - lr_warm_up_steps=2200, - train_batch_size = 4096, - context_size = 128, - - # Activation Store Parameters - n_batches_in_buffer = 512, - total_training_tokens = 3_000_000, - store_batch_size = 32, - - # Resampling protocol - feature_sampling_method = 'l2', - feature_sampling_window = 1000, - feature_reinit_scale = 0.2, - dead_feature_window=200, - dead_feature_threshold = 5e-6, - - # WANDB - log_to_wandb = True, - wandb_project= "mats_sae_training_gpt2_small_hook_q_dev", - wandb_entity = None, - wandb_log_frequency=5, - - # Misc - device = "mps", - seed = 42, - n_checkpoints = 0, - checkpoint_path = "checkpoints", - dtype = torch.float32, - ) - - sparse_autoencoder = language_model_sae_runner(cfg) assert sparse_autoencoder is not None + # know whether or not this works by looking at the dashbaord! -def test_language_model_sae_runner_not_tokenized(): - - cfg = LanguageModelSAERunnerConfig( - - # Data Generating Function (Model + Training Distibuion) - model_name = "gelu-2l", - hook_point = "blocks.1.hook_mlp_out", - hook_point_layer = 0, - d_in = 512, - dataset_path = "roneneldan/TinyStories", - is_dataset_tokenized=False, - - # SAE Parameters - expansion_factor = 64, # determines the dimension of the SAE. - - # Training Parameters - lr = 1e-4, - l1_coefficient = 1e-4, - train_batch_size = 4096, - context_size = 128, - - # Activation Store Parameters - n_batches_in_buffer = 8, - total_training_tokens = 25_000_00 * 60, - store_batch_size = 32, - - # Resampling protocol - feature_sampling_window = 1000, - feature_reinit_scale = 0.2, - dead_feature_threshold = 1e-8, - - # WANDB - log_to_wandb = True, - wandb_project= "mats_sae_training_language_models", - wandb_entity = None, - - # Misc - device = "mps", - seed = 42, - checkpoint_path = "checkpoints", - dtype = torch.float32, - ) - - trained_sae = language_model_sae_runner(cfg) - - assert trained_sae is not None diff --git a/tests/benchmark/test_toy_model_sae_runner.py b/tests/benchmark/test_toy_model_sae_runner.py index e1dbb0cd..8bf762bf 100644 --- a/tests/benchmark/test_toy_model_sae_runner.py +++ b/tests/benchmark/test_toy_model_sae_runner.py @@ -1,15 +1,24 @@ import pytest +import torch from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner # @pytest.mark.skip(reason="I (joseph) broke this at some point, on my to do list to fix.") def test_toy_model_sae_runner(): + + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + cfg = SAEToyModelRunnerConfig( # Model Details - n_features=10, - n_hidden=2, + n_features=100, + n_hidden=10, n_correlated_pairs=0, n_anticorrelated_pairs=0, feature_probability=0.025, @@ -17,20 +26,22 @@ def test_toy_model_sae_runner(): # SAE Parameters d_sae=10, + lr = 3e-4, l1_coefficient=0.001, + use_ghost_grads=False, + b_dec_init_method="mean", # SAE Train Config train_batch_size=1028, feature_sampling_window=3_000, dead_feature_window=1_000, - feature_reinit_scale=0.5, - total_training_tokens=4096*300, + total_training_tokens=4096*1000, # Other parameters log_to_wandb=True, - wandb_project="sae-training-test", + wandb_project= "mats_sae_training_benchmarks_toy", wandb_log_frequency=5, - device="mps", + device=device, ) trained_sae = toy_model_sae_runner(cfg) diff --git a/tests/unit/test_sparse_autoencoder.py b/tests/unit/test_sparse_autoencoder.py index 
a24301be..7224893a 100644 --- a/tests/unit/test_sparse_autoencoder.py +++ b/tests/unit/test_sparse_autoencoder.py @@ -36,7 +36,6 @@ def cfg(): mock_config.context_size = 64 mock_config.feature_sampling_method = None mock_config.feature_sampling_window = 50 - mock_config.resample_batches = 4 mock_config.feature_reinit_scale = 0.1 mock_config.dead_feature_threshold = 1e-7 mock_config.n_batches_in_buffer = 10 @@ -187,8 +186,8 @@ def test_sparse_autoencoder_forward(sparse_autoencoder): assert l1_loss.shape == () assert torch.allclose(loss, mse_loss + l1_loss) - - expected_mse_loss = (torch.pow((sae_out-x.float()), 2) / (x**2).sum(dim=-1, keepdim=True).sqrt()).mean() + x_centred = x - x.mean(dim=0, keepdim=True) + expected_mse_loss = (torch.pow((sae_out-x.float()), 2) / (x_centred**2).sum(dim=-1, keepdim=True).sqrt()).mean() assert torch.allclose(mse_loss, expected_mse_loss) expected_l1_loss = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,)) assert torch.allclose(l1_loss, sparse_autoencoder.l1_coefficient * expected_l1_loss) From 91aca9142459e2a72478d4546ee0fa5a2910c161 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Wed, 7 Feb 2024 18:52:09 -0800 Subject: [PATCH 2/9] add ci --- .github/workflows/tests.yaml | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 .github/workflows/tests.yaml diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 00000000..06da723a --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,56 @@ +name: build + +on: + push: + paths-ignore: + - '.devcontainer/**' + - '.github/**' + - '.vscode/**' + - '.gitignore' + - '*.md' + pull_request: + branches: + - main + paths-ignore: + - '.devcontainer/**' + - '.github/**' + - '.vscode/**' + - '.gitignore' + - '*.md' + # Allow this workflow to be called from other workflows + workflow_call: + inputs: + # Requires at least one input to be valid, but in practice we don't need any + dummy: + type: string + required: false + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pytest + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Run Unit Tests + run: | + make unit-test From 9f3f1c87ceed84afefef5bc6d4b519d538bcac91 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Wed, 7 Feb 2024 18:53:28 -0800 Subject: [PATCH 3/9] yml not yaml --- .github/workflows/{tests.yaml => tests.yml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename .github/workflows/{tests.yaml => tests.yml} (100%) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yml similarity index 100% rename from .github/workflows/tests.yaml rename to .github/workflows/tests.yml From 7fd0e0c3d3f951ae18b6b101043ba4a36c9933c4 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Wed, 7 Feb 2024 18:55:50 -0800 Subject: [PATCH 4/9] try adding this branch listed specifically --- .github/workflows/tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 06da723a..21f5fd89 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,6 +2,9 @@ name: build on: push: + branches: + - main + - clean_up_repo paths-ignore: - '.devcontainer/**' - '.github/**' From 912a748f7b5bff0193a997ee7369c312f073c35f Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Wed, 7 Feb 2024 18:57:18 -0800 Subject: [PATCH 5/9] dummy file change --- sae_training/sparse_autoencoder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sae_training/sparse_autoencoder.py b/sae_training/sparse_autoencoder.py index 967ffe2a..1096e595 100644 --- a/sae_training/sparse_autoencoder.py +++ b/sae_training/sparse_autoencoder.py @@ -101,6 +101,8 @@ def forward(self, x, dead_neuron_mask = None): x_centred = x - x.mean(dim=0, keepdim=True) mse_loss = (torch.pow((sae_out-x.float()), 2) / (x_centred**2).sum(dim=-1, keepdim=True).sqrt()) + + mse_loss_ghost_resid = torch.tensor(0.0, dtype=self.dtype, device=self.device) # gate on config and training so evals is not slowed down. 
if self.cfg.use_ghost_grads and self.training and dead_neuron_mask.sum() > 0: From 479765bc9ecc4f04fe76e1f4f447d0d0281e9045 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Wed, 7 Feb 2024 19:11:05 -0800 Subject: [PATCH 6/9] black format and linting --- .flake8 | 7 + .gitignore | 2 +- .pre-commit-config.yaml | 25 + .pylintrc | 2 +- .vscode/settings.json | 2 +- requirements.txt | 2 +- sae_analysis/dashboard_runner.py | 283 +++++---- sae_analysis/visualizer/README.md | 2 +- sae_analysis/visualizer/css/general.css | 2 +- sae_analysis/visualizer/css/sequences.css | 3 - sae_analysis/visualizer/css/tables.css | 2 +- sae_analysis/visualizer/data_fns.py | 589 +++++++++++------- .../visualizer/html/frequency_histogram.html | 2 +- .../visualizer/html/hovertext_script.html | 2 +- .../visualizer/html/logit_table_template.html | 2 +- .../visualizer/html/logits_histogram.html | 2 +- .../visualizer/html/token_template.html | 2 +- sae_analysis/visualizer/html_fns.py | 157 +++-- sae_analysis/visualizer/model_fns.py | 44 +- sae_analysis/visualizer/utils_fns.py | 91 +-- sae_training/activations_store.py | 163 ++--- sae_training/cache_activations_runner.py | 38 +- sae_training/config.py | 5 +- sae_training/evals.py | 123 ++-- sae_training/geom_median/setup.py | 4 +- .../geom_median/src/geom_median/__init__.py | 1 - .../src/geom_median/numpy/__init__.py | 2 +- .../geom_median/src/geom_median/numpy/main.py | 71 ++- .../src/geom_median/numpy/utils.py | 53 +- .../src/geom_median/numpy/weiszfeld_array.py | 21 +- .../numpy/weiszfeld_list_of_array.py | 23 +- .../src/geom_median/torch/__init__.py | 2 +- .../geom_median/src/geom_median/torch/main.py | 73 ++- .../src/geom_median/torch/utils.py | 53 +- .../src/geom_median/torch/weiszfeld_array.py | 28 +- .../torch/weiszfeld_list_of_array.py | 30 +- sae_training/lm_runner.py | 42 +- sae_training/optim.py | 28 +- sae_training/sparse_autoencoder.py | 159 ++--- sae_training/toy_model_runner.py | 32 +- sae_training/toy_models.py | 301 ++++++--- sae_training/train_sae_on_language_model.py | 99 +-- sae_training/train_sae_on_toy_model.py | 14 +- sae_training/utils.py | 94 +-- scripts/generate_dashboards.py | 305 +++++---- .../test_language_model_sae_runner.py | 63 +- tests/benchmark/test_toy_model_sae_runner.py | 13 +- tests/unit/test_activations_store.py | 71 +-- tests/unit/test_sparse_autoencoder.py | 119 ++-- tests/unit/test_utils.py | 25 +- 50 files changed, 1921 insertions(+), 1357 deletions(-) create mode 100644 .flake8 create mode 100644 .pre-commit-config.yaml diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..138a4973 --- /dev/null +++ b/.flake8 @@ -0,0 +1,7 @@ +[flake8] +ignore = E203, E266, E501, W503 +max-line-length = 79 +max-complexity = 10 +select = E9, F63, F7, F82 +show-source = true +statistics = true diff --git a/.gitignore b/.gitignore index 29ebafe0..754fd271 100644 --- a/.gitignore +++ b/.gitignore @@ -170,4 +170,4 @@ activations/ *.DS_Store feature_dashboards/ -research/ \ No newline at end of file +research/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..c45c4453 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: [--maxkb=250000] +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black +- repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + additional_dependencies: [ 
+ 'flake8-blind-except', + 'flake8-docstrings', + 'flake8-bugbear', + 'flake8-comprehensions', + 'flake8-docstrings', + 'flake8-implicit-str-concat', + 'pydocstyle>=5.0.0', + ] diff --git a/.pylintrc b/.pylintrc index 0127ca79..d4edf945 100644 --- a/.pylintrc +++ b/.pylintrc @@ -16,4 +16,4 @@ default-docstring-type = numpy max-line-length = 100 [MESSAGES CONTROL] -disable = C0330, C0326, C0199, C0411, C103, C0303, C0304 \ No newline at end of file +disable = C0330, C0326, C0199, C0411, C103, C0303, C0304 diff --git a/.vscode/settings.json b/.vscode/settings.json index 6dedeb83..cdf373ce 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -17,4 +17,4 @@ "black" ], "editor.defaultFormatter": "mikoz.black-py", -} \ No newline at end of file +} diff --git a/requirements.txt b/requirements.txt index 7737a85c..6c962a30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ black==23.11.0 pytest==7.4.3 pytest-cov==4.1.0 pre-commit==3.6.0 -git+https://github.com/callummcdougall/eindex.git \ No newline at end of file +git+https://github.com/callummcdougall/eindex.git diff --git a/sae_analysis/dashboard_runner.py b/sae_analysis/dashboard_runner.py index a2d1b9ea..3ff2bc7b 100644 --- a/sae_analysis/dashboard_runner.py +++ b/sae_analysis/dashboard_runner.py @@ -22,31 +22,27 @@ from sae_training.utils import LMSparseAutoencoderSessionloader -class DashboardRunner(): - +class DashboardRunner: def __init__( self, sae_path: str = None, dashboard_parent_folder: str = "./feature_dashboards", wandb_artifact_path: str = None, init_session: bool = True, - # token pars n_batches_to_sample_from: int = 2**12, - n_prompts_to_select: int = 4096*6, - + n_prompts_to_select: int = 4096 * 6, # sampling pars n_features_at_a_time: int = 1024, max_batch_size: int = 256, buffer_tokens: int = 8, - # util pars use_wandb: bool = False, continue_existing_dashboard: bool = True, final_index: int = None, ): - ''' - # # test it + """ + # # test it # runner = DashboardRunner( # sae_path = None, @@ -64,11 +60,10 @@ def __init__( # runner.run() - - ''' - - if wandb_artifact_path is not None: + """ + + if wandb_artifact_path is not None: artifact_dir = f"artifacts/{wandb_artifact_path.split('/')[2]}" if not os.path.exists(artifact_dir): print("Downloading artifact") @@ -77,92 +72,106 @@ def __init__( artifact_dir = artifact.download() path_to_artifact = f"{artifact_dir}/{os.listdir(artifact_dir)[0]}" # feature sparsity - feature_sparsity_path = self.get_feature_sparsity_path(wandb_artifact_path) + feature_sparsity_path = self.get_feature_sparsity_path( + wandb_artifact_path + ) artifact = run.use_artifact(feature_sparsity_path) artifact_dir = artifact.download() # add it as a property - self.feature_sparsity = torch.load(f"{artifact_dir}/{os.listdir(artifact_dir)[0]}") + self.feature_sparsity = torch.load( + f"{artifact_dir}/{os.listdir(artifact_dir)[0]}" + ) else: print("Artifact already downloaded") path_to_artifact = f"{artifact_dir}/{os.listdir(artifact_dir)[0]}" - - feature_sparsity_path = self.get_feature_sparsity_path(wandb_artifact_path) + + feature_sparsity_path = self.get_feature_sparsity_path( + wandb_artifact_path + ) artifact_dir = f"artifacts/{feature_sparsity_path.split('/')[2]}" feature_sparsity_file = os.listdir(artifact_dir)[0] - self.feature_sparsity = torch.load(f"{artifact_dir}/{feature_sparsity_file}") - + self.feature_sparsity = torch.load( + f"{artifact_dir}/{feature_sparsity_file}" + ) + self.sae_path = path_to_artifact - else: + else: assert sae_path is not None self.sae_path = 
sae_path - + if init_session: self.init_sae_session() - + self.n_features_at_a_time = n_features_at_a_time self.max_batch_size = max_batch_size self.buffer_tokens = buffer_tokens self.use_wandb = use_wandb - self.final_index = final_index if final_index is not None else self.sparse_autoencoder.cfg.d_sae + self.final_index = ( + final_index + if final_index is not None + else self.sparse_autoencoder.cfg.d_sae + ) self.n_batches_to_sample_from = n_batches_to_sample_from self.n_prompts_to_select = n_prompts_to_select - - + # Deal with file structure if not os.path.exists(dashboard_parent_folder): os.makedirs(dashboard_parent_folder) - self.dashboard_folder = f"{dashboard_parent_folder}/{self.get_dashboard_folder_name()}" + self.dashboard_folder = ( + f"{dashboard_parent_folder}/{self.get_dashboard_folder_name()}" + ) if not os.path.exists(self.dashboard_folder): os.makedirs(self.dashboard_folder) - + if not continue_existing_dashboard: # check if there are files there and if so abort if len(os.listdir(self.dashboard_folder)) > 0: raise ValueError("Dashboard folder not empty. Aborting.") def get_feature_sparsity_path(self, wandb_artifact_path): - prefix = wandb_artifact_path.split(':')[0] + prefix = wandb_artifact_path.split(":")[0] return f"{prefix}_log_feature_sparsity:v9" - + def get_dashboard_folder_name(self): - model = self.sparse_autoencoder.cfg.model_name hook_point = self.sparse_autoencoder.cfg.hook_point d_sae = self.sparse_autoencoder.cfg.d_sae dashboard_folder_name = f"{model}_{hook_point}_{d_sae}" - + return dashboard_folder_name - + def init_sae_session(self): - - self.model, self.sparse_autoencoder, self.activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained( - self.sae_path - ) - - def get_tokens(self, n_batches_to_sample_from = 2**12, n_prompts_to_select = 4096*6): - ''' + ( + self.model, + self.sparse_autoencoder, + self.activation_store, + ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path) + + def get_tokens( + self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6 + ): + """ Get the tokens needed for dashboard generation. 
- ''' - + """ + all_tokens_list = [] pbar = tqdm(range(n_batches_to_sample_from)) for _ in pbar: - batch_tokens = self.activation_store.get_batch_tokens() - batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][:batch_tokens.shape[0]] + batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][ + : batch_tokens.shape[0] + ] all_tokens_list.append(batch_tokens) - + all_tokens = torch.cat(all_tokens_list, dim=0) all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])] return all_tokens[:n_prompts_to_select] def get_index_to_resume_from(self): - for i in range(self.n_features): - if not os.path.exists(f"{self.dashboard_folder}/data_{i:04}.html"): - break + break n_features = self.sparse_autoencoder.cfg.d_sae n_features_at_a_time = self.n_features_at_a_time @@ -170,91 +179,119 @@ def get_index_to_resume_from(self): n_features_remaining = self.final_index - id_of_last_feature_without_dashboard n_batches_to_do = n_features_remaining // n_features_at_a_time if self.final_index == n_features: - id_to_start_from = max(0, n_features - (n_batches_to_do + 1) * n_features_at_a_time) + id_to_start_from = max( + 0, n_features - (n_batches_to_do + 1) * n_features_at_a_time + ) else: - id_to_start_from = 0 # testing purposes only - - + id_to_start_from = 0 # testing purposes only + print(f"File {i} does not exist") print(f"features left to do: {n_features_remaining}") print(f"id_to_start_from: {id_to_start_from}") - print(f"number of batches to do: {(n_features - id_to_start_from) // n_features_at_a_time}") - + print( + f"number of batches to do: {(n_features - id_to_start_from) // n_features_at_a_time}" + ) + return id_to_start_from - + @torch.no_grad() def get_feature_property_df(self): - - sparse_autoencoder= self.sparse_autoencoder + sparse_autoencoder = self.sparse_autoencoder feature_sparsity = self.feature_sparsity - - W_dec_normalized = sparse_autoencoder.W_dec.cpu()# / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True) - W_enc_normalized = sparse_autoencoder.W_enc.cpu() / sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True) + + W_dec_normalized = ( + sparse_autoencoder.W_dec.cpu() + ) # / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True) + W_enc_normalized = ( + sparse_autoencoder.W_enc.cpu() + / sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True) + ) d_e_projection = cosine_similarity(W_dec_normalized, W_enc_normalized.T) b_dec_projection = sparse_autoencoder.b_dec.cpu() @ W_dec_normalized.T - temp_df = pd.DataFrame({ - "log_feature_sparsity": feature_sparsity + 1e-10, - "d_e_projection": d_e_projection, - # "d_e_projection_normalized": d_e_projection_normalized, - "b_enc": sparse_autoencoder.b_enc.detach().cpu(), - "feature": [f"feature_{i}" for i in range(sparse_autoencoder.cfg.d_sae)], - "index": torch.arange(sparse_autoencoder.cfg.d_sae), - "dead_neuron": (feature_sparsity < -9).cpu(), - }) - + temp_df = pd.DataFrame( + { + "log_feature_sparsity": feature_sparsity + 1e-10, + "d_e_projection": d_e_projection, + # "d_e_projection_normalized": d_e_projection_normalized, + "b_enc": sparse_autoencoder.b_enc.detach().cpu(), + "feature": [ + f"feature_{i}" for i in range(sparse_autoencoder.cfg.d_sae) + ], + "index": torch.arange(sparse_autoencoder.cfg.d_sae), + "dead_neuron": (feature_sparsity < -9).cpu(), + } + ) + return temp_df - - + def run(self): - ''' + """ Generate the dashboard. 
- ''' - + """ + if self.use_wandb: # get name from wandb - random_suffix= str(uuid.uuid4())[:8] + random_suffix = str(uuid.uuid4())[:8] name = f"{self.get_dashboard_folder_name()}_{random_suffix}" run = wandb.init( project="feature_dashboards", config=self.sparse_autoencoder.cfg, - name = name, - tags = [ + name=name, + tags=[ f"model_{self.sparse_autoencoder.cfg.model_name}", f"hook_point_{self.sparse_autoencoder.cfg.hook_point}", - ] + ], ) - + if self.model is None: self.init_sae_session() - # generate all the plots if self.use_wandb: feature_property_df = self.get_feature_property_df() - - fig = px.histogram(runner.feature_sparsity+1e-10, nbins=100, log_x=False, title="Feature sparsity") - wandb.log({"plots/feature_density_histogram": wandb.Html(plotly.io.to_html(fig))}) - fig = px.histogram(self.sparse_autoencoder.b_enc.detach().cpu(), title = "b_enc", nbins = 100) + fig = px.histogram( + feature_property_df.log_feature_sparsity, + nbins=100, + log_x=False, + title="Feature sparsity", + ) + wandb.log( + {"plots/feature_density_histogram": wandb.Html(plotly.io.to_html(fig))} + ) + + fig = px.histogram( + self.sparse_autoencoder.b_enc.detach().cpu(), title="b_enc", nbins=100 + ) wandb.log({"plots/b_enc_histogram": wandb.Html(plotly.io.to_html(fig))}) - - fig = px.histogram(feature_property_df.d_e_projection, nbins = 100, title = "D/E projection") - wandb.log({"plots/d_e_projection_histogram": wandb.Html(plotly.io.to_html(fig))}) - - fig = px.histogram(self.sparse_autoencoder.b_dec.detach().cpu(), nbins=100, title = "b_dec projection onto W_dec") - wandb.log({"plots/b_dec_projection_histogram": wandb.Html(plotly.io.to_html(fig))}) - - fig = px.scatter_matrix(feature_property_df, - dimensions = ["log_feature_sparsity", "d_e_projection", "b_enc"], + + fig = px.histogram( + feature_property_df.d_e_projection, nbins=100, title="D/E projection" + ) + wandb.log( + {"plots/d_e_projection_histogram": wandb.Html(plotly.io.to_html(fig))} + ) + + fig = px.histogram( + self.sparse_autoencoder.b_dec.detach().cpu(), + nbins=100, + title="b_dec projection onto W_dec", + ) + wandb.log( + {"plots/b_dec_projection_histogram": wandb.Html(plotly.io.to_html(fig))} + ) + + fig = px.scatter_matrix( + feature_property_df, + dimensions=["log_feature_sparsity", "d_e_projection", "b_enc"], color="dead_neuron", hover_name="feature", opacity=0.2, height=800, - width =1400, + width=1400, ) wandb.log({"plots/scatter_matrix": wandb.Html(plotly.io.to_html(fig))}) - self.n_features = self.sparse_autoencoder.cfg.d_sae id_to_start_from = self.get_index_to_resume_from() @@ -264,19 +301,21 @@ def run(self): feature_idx = torch.tensor(range(id_to_start_from, id_to_end_at)) feature_idx = feature_idx.reshape(-1, self.n_features_at_a_time) feature_idx = [x.tolist() for x in feature_idx] - + print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}") print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}") print(f"Writing files to: {self.dashboard_folder}") # get tokens: start = time.time() - tokens = self.get_tokens(self.n_batches_to_sample_from, self.n_prompts_to_select) + tokens = self.get_tokens( + self.n_batches_to_sample_from, self.n_prompts_to_select + ) end = time.time() print(f"Time to get tokens: {end - start}") if self.use_wandb: wandb.log({"time/time_to_get_tokens": end - start}) - + with torch.no_grad(): for interesting_features in tqdm(feature_idx): print(interesting_features) @@ -290,45 +329,51 @@ def run(self): tokens=tokens, feature_idx=interesting_features, 
max_batch_size=self.max_batch_size, - left_hand_k = 3, - buffer = (self.buffer_tokens, self.buffer_tokens), - n_groups = 10, - first_group_size = 20, - other_groups_size = 5, - verbose = True, + left_hand_k=3, + buffer=(self.buffer_tokens, self.buffer_tokens), + n_groups=10, + first_group_size=20, + other_groups_size=5, + verbose=True, ) - + for i, test_idx in enumerate(feature_data.keys()): html_str = feature_data[test_idx].get_all_html() - with open(f"{self.dashboard_folder}/data_{test_idx:04}.html", "w") as f: + with open( + f"{self.dashboard_folder}/data_{test_idx:04}.html", "w" + ) as f: f.write(html_str) - + if i < 10 and self.use_wandb: # upload the html as an artifact artifact = wandb.Artifact(f"feature_{test_idx}", type="feature") - artifact.add_file(f"{self.dashboard_folder}/data_{test_idx:04}.html") + artifact.add_file( + f"{self.dashboard_folder}/data_{test_idx:04}.html" + ) run.log_artifact(artifact) - + # also upload as html to dashboard wandb.log( - {f"features/feature_dashboard": wandb.Html(f"{self.dashboard_folder}/data_{test_idx:04}.html")}, - step = test_idx - ) - + { + f"features/feature_dashboard": wandb.Html( + f"{self.dashboard_folder}/data_{test_idx:04}.html" + ) + }, + step=test_idx, + ) + # when done zip the folder - shutil.make_archive(self.dashboard_folder, 'zip', self.dashboard_folder) - + shutil.make_archive(self.dashboard_folder, "zip", self.dashboard_folder) + # then upload the zip as an artifact artifact = wandb.Artifact("dashboard", type="zipped_feature_dashboards") artifact.add_file(f"{self.dashboard_folder}.zip") run.log_artifact(artifact) - + # terminate the run run.finish() - + # delete the dashboard folder shutil.rmtree(self.dashboard_folder) - - return - + return diff --git a/sae_analysis/visualizer/README.md b/sae_analysis/visualizer/README.md index 15a9ac58..829f2484 100644 --- a/sae_analysis/visualizer/README.md +++ b/sae_analysis/visualizer/README.md @@ -15,4 +15,4 @@ This particular feature seems to be a fuzzy skip trigram, with the pattern being These visualisations were created using the GELU-1l model from Neel Nanda's HuggingFace library, as well as an autoencoder which he trained on its single layer of neuron activations (see [this Colab](https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn) from Neel). -You can use my [Colab]() to generate more of these visualisations. You can use this [sae visualiser](https://www.perfectlynormal.co.uk/blog-sae) to navigate through the first thousand features of the aforementioned autoencoder. \ No newline at end of file +You can use my [Colab]() to generate more of these visualisations. You can use this [sae visualiser](https://www.perfectlynormal.co.uk/blog-sae) to navigate through the first thousand features of the aforementioned autoencoder. 
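For context, a minimal sketch of how the reformatted DashboardRunner above might be invoked. Parameter names and default values are taken from the constructor shown in this diff; the sae_path value is a hypothetical placeholder, not a file from this repository.

    from sae_analysis.dashboard_runner import DashboardRunner

    # Illustrative only: sae_path is a placeholder path to a saved SparseAutoencoder;
    # the remaining arguments mirror the defaults in the constructor above.
    runner = DashboardRunner(
        sae_path="path/to/sparse_autoencoder.pt",  # placeholder
        dashboard_parent_folder="./feature_dashboards",
        wandb_artifact_path=None,
        init_session=True,
        n_batches_to_sample_from=2**12,
        n_prompts_to_select=4096 * 6,
        n_features_at_a_time=1024,
        max_batch_size=256,
        buffer_tokens=8,
        use_wandb=False,
        continue_existing_dashboard=True,
    )
    runner.run()

With init_session=True the runner loads the model, SAE, and activation store via LMSparseAutoencoderSessionloader, then writes one HTML dashboard file per feature batch into dashboard_parent_folder.
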
diff --git a/sae_analysis/visualizer/css/general.css b/sae_analysis/visualizer/css/general.css index dcb5f5fc..7b5404c4 100644 --- a/sae_analysis/visualizer/css/general.css +++ b/sae_analysis/visualizer/css/general.css @@ -25,4 +25,4 @@ table { } code { all: unset; -} \ No newline at end of file +} diff --git a/sae_analysis/visualizer/css/sequences.css b/sae_analysis/visualizer/css/sequences.css index b45b7f7f..bcef8d6e 100644 --- a/sae_analysis/visualizer/css/sequences.css +++ b/sae_analysis/visualizer/css/sequences.css @@ -129,6 +129,3 @@ table code { width: 50%; margin-right: -4px; } - - - diff --git a/sae_analysis/visualizer/css/tables.css b/sae_analysis/visualizer/css/tables.css index ffda1234..fdf82c6e 100644 --- a/sae_analysis/visualizer/css/tables.css +++ b/sae_analysis/visualizer/css/tables.css @@ -14,4 +14,4 @@ h4 { } .code-bold code { font-weight: bold; -} \ No newline at end of file +} diff --git a/sae_analysis/visualizer/data_fns.py b/sae_analysis/visualizer/data_fns.py index 446158c7..c6b1c486 100644 --- a/sae_analysis/visualizer/data_fns.py +++ b/sae_analysis/visualizer/data_fns.py @@ -46,8 +46,8 @@ class HistogramData: - ''' - Class for storing all the data necessary to construct a histogram (because e.g. + """ + Class for storing all the data necessary to construct a histogram (because e.g. for a vector with length `d_vocab`, we don't need to store it all!). This is initialised with a tensor of data, and it automatically calculates & stores @@ -55,9 +55,9 @@ class HistogramData: This isn't a dataclass, because the things we hold at the end are not the same as the things we start with! - ''' - def __init__(self, data: Tensor, n_bins: int, tickmode: str): + """ + def __init__(self, data: Tensor, n_bins: int, tickmode: str): if data.numel() == 0: self.bar_heights = [] self.bar_values = [] @@ -74,7 +74,7 @@ def __init__(self, data: Tensor, n_bins: int, tickmode: str): # calculate the heights of each bin bar_heights = torch.histc(data, bins=n_bins) bar_values = bin_edges[:-1] + bin_size / 2 - + # choose tickvalues (super hacky and terrible, should improve this) assert tickmode in ["ints", "5 ticks"] @@ -92,9 +92,13 @@ def __init__(self, data: Tensor, n_bins: int, tickmode: str): num_negative_ticks = 3 num_positive_ticks = int(max_value / tickrange) tick_vals = merge_lists( - reversed([-tickrange * i for i in range(1, 1+num_negative_ticks)]), # negative values (if exist) - [0], # zero (always is a tick) - [tickrange * i for i in range(1, 1+num_positive_ticks)] # positive values + reversed( + [-tickrange * i for i in range(1, 1 + num_negative_ticks)] + ), # negative values (if exist) + [0], # zero (always is a tick) + [ + tickrange * i for i in range(1, 1 + num_positive_ticks) + ], # positive values ) self.bar_heights = bar_heights.tolist() @@ -102,10 +106,9 @@ def __init__(self, data: Tensor, n_bins: int, tickmode: str): self.tick_vals = tick_vals - @dataclass class SequenceData: - ''' + """ Class to store data for a given sequence, which will be turned into a JavaScript visulisation. 
Before hover: @@ -119,7 +122,8 @@ class SequenceData: top5_logit_changes: list of the corresponding 5 changes in logits for those tokens bottom5_str_tokens: list of the bottom 5 logit-boosted tokens by this feature bottom5_logit_changes: list of the corresponding 5 changes in logits for those tokens - ''' + """ + token_ids: List[str] feat_acts: List[float] contribution_to_loss: List[float] @@ -134,64 +138,70 @@ def __len__(self): def __str__(self): return f"SequenceData({''.join(self.token_ids)})" - + def __post_init__(self): - '''Filters down the data, by deleting the "on hover" information if the activations are zero.''' - self.top5_logit_contributions, self.top5_token_ids = self._filter(self.top5_logit_contributions, self.top5_token_ids) - self.bottom5_logit_contributions, self.bottom5_token_ids = self._filter(self.bottom5_logit_contributions, self.bottom5_token_ids) + """Filters down the data, by deleting the "on hover" information if the activations are zero.""" + self.top5_logit_contributions, self.top5_token_ids = self._filter( + self.top5_logit_contributions, self.top5_token_ids + ) + self.bottom5_logit_contributions, self.bottom5_token_ids = self._filter( + self.bottom5_logit_contributions, self.bottom5_token_ids + ) def _filter(self, float_list: List[List[float]], int_list: List[List[str]]): float_list = [[f for f in floats if f != 0] for floats in float_list] - int_list = [[i for i, f in zip(ints, floats)] for ints, floats in zip(int_list, float_list)] + int_list = [ + [i for i, f in zip(ints, floats)] + for ints, floats in zip(int_list, float_list) + ] return float_list, int_list - class SequenceDataBatch: - ''' + """ Class to store a list of SequenceData objects at once, by passing in tensors or objects with an extra dimension at the start. Note, I'll be creating these objects by passing in objects which are either 2D (k seq_len) or 3D (k seq_len top5), but which are all lists (of strings/ints/floats). - ''' + """ + def __init__(self, **kwargs): self.seqs = [ SequenceData( - token_ids = kwargs["token_ids"][k], - feat_acts = kwargs["feat_acts"][k], - contribution_to_loss = kwargs["contribution_to_loss"][k], - repeat = kwargs["repeat"], - top5_token_ids = kwargs["top5_token_ids"][k], - top5_logit_contributions = kwargs["top5_logit_contributions"][k], - bottom5_token_ids = kwargs["bottom5_token_ids"][k], - bottom5_logit_contributions = kwargs["bottom5_logit_contributions"][k], + token_ids=kwargs["token_ids"][k], + feat_acts=kwargs["feat_acts"][k], + contribution_to_loss=kwargs["contribution_to_loss"][k], + repeat=kwargs["repeat"], + top5_token_ids=kwargs["top5_token_ids"][k], + top5_logit_contributions=kwargs["top5_logit_contributions"][k], + bottom5_token_ids=kwargs["bottom5_token_ids"][k], + bottom5_logit_contributions=kwargs["bottom5_logit_contributions"][k], ) for k in range(len(kwargs["token_ids"])) ] def __getitem__(self, idx: int) -> SequenceData: return self.seqs[idx] - + def __len__(self) -> int: return len(self.seqs) - + def __str__(self) -> str: return "\n".join([str(seq) for seq in self.seqs]) - @dataclass class FeatureData: - ''' + """ Class to store all data for a feature that will be used in the visualization. Also has a bunch of methods to create visualisations. So this is the main important class. The biggest arg is `sequence_data`, it's explained below. The other args are individual, and are used to construct the left-hand visualisations. 
- + Args for the right-hand sequences: sequence_data: Dict[str, SequenceDataBatch] @@ -199,9 +209,9 @@ class FeatureData: Each key is a group name (there are 12 in total: top, bottom, 10 quantiles), and each value is a SequenceDataBatch object (i.e. it contains a batch of SequenceData objects, one for each sequence in the group). See these classes for more on how these are used. - + Args for the middle column: - + top10_logits: Tuple[TopK, TopK] Contains the most neg / pos 10 logits, used for the logits table @@ -215,17 +225,17 @@ class FeatureData: Also used for frequencies histogram, this is the fraction of activations which are non-zero Args for the left-hand column - + neuron_alignment: Tuple[TopK, Tensor] first element is the topk aligned neurons (i.e. argmax on decoder weights) second element is the fraction of L1 norm this neuron makes up, in this decoder weight vector. - + neurons_correlated: Tuple[TopK, TopK] the topk neurons most correlated with each other, i.e. this feature has (N,) activations and the neurons have (d_mlp, N) activations on these tokens, where N = batch_size * seq_len, and - we find the neuron (column of second tensor) with highest correlation. Contains Pearson & + we find the neuron (column of second tensor) with highest correlation. Contains Pearson & Cosine sim (difference is that Pearson centers weights first). - + b_features_correlated: Tuple[TopK, TopK] same datatype as neurons_correlated, but now we're looking at this feature's (N,) activations and comparing them to the (h, N) activations of the encoder-B features (where h is the hidden @@ -235,7 +245,7 @@ class FeatureData: model: HookedTransformer The thing you're actually doing forward passes through, and finding features of - + encoder: AutoEncoder The encoder of the model, which you're using to find features @@ -245,9 +255,10 @@ class FeatureData: n_groups, first_group_size, other_groups_size All params to determine size of the sequences in right hand of visualisation. - ''' + """ + sequence_data: Dict[str, SequenceDataBatch] - + top10_logits: Tuple[TopK, TopK] logits_histogram_data: HistogramData frequencies_histogram_data: HistogramData @@ -263,25 +274,26 @@ class FeatureData: first_group_size: int = 20 other_groups_size: int = 5 - def return_save_dict(self) -> dict: - '''Returns a dict we use for saving (pickling).''' - return { - k: v for k, v in self.__dict__.items() - if k not in ["vocab_dict"] - } - + """Returns a dict we use for saving (pickling).""" + return {k: v for k, v in self.__dict__.items() if k not in ["vocab_dict"]} @classmethod def load_from_save_dict(self, save_dict, vocab_dict): - '''Loads this object from a dict (e.g. from a pickle file).''' + """Loads this object from a dict (e.g. from a pickle file).""" return FeatureData(**save_dict, vocab_dict=vocab_dict) - @classmethod - def save_batch(cls, batch: Dict[int, "FeatureData"], filename: str, save_type: Literal["pkl", "gzip"]) -> None: - '''Saves a batch of FeatureData objects to a pickle file.''' - assert "." not in filename, "You should pass in the filename without the extension." + def save_batch( + cls, + batch: Dict[int, "FeatureData"], + filename: str, + save_type: Literal["pkl", "gzip"], + ) -> None: + """Saves a batch of FeatureData objects to a pickle file.""" + assert ( + "." not in filename + ), "You should pass in the filename without the extension." 
filename = filename + ".pkl" if (save_type == "pkl") else filename + ".pkl.gz" save_obj = {k: v.return_save_dict() for k, v in batch.items()} if save_type == "pkl": @@ -291,13 +303,22 @@ def save_batch(cls, batch: Dict[int, "FeatureData"], filename: str, save_type: L with gzip.open(filename, "wb") as f: pickle.dump(save_obj, f) return filename - @classmethod - def load_batch(cls, filename: str, save_type: Literal["pkl", "gzip"], vocab_dict: Dict[int, str], feature_idx: Optional[int] = None) -> Union["FeatureData", Dict[int, "FeatureData"]]: - '''Loads a batch of FeatureData objects from a pickle file.''' - assert "." not in filename, "You should pass in the filename without the extension." - filename = filename + ".pkl" if save_type.startswith("pkl") else filename + ".pkl.gz" + def load_batch( + cls, + filename: str, + save_type: Literal["pkl", "gzip"], + vocab_dict: Dict[int, str], + feature_idx: Optional[int] = None, + ) -> Union["FeatureData", Dict[int, "FeatureData"]]: + """Loads a batch of FeatureData objects from a pickle file.""" + assert ( + "." not in filename + ), "You should pass in the filename without the extension." + filename = ( + filename + ".pkl" if save_type.startswith("pkl") else filename + ".pkl.gz" + ) if save_type.startswith("pkl"): with open(filename, "rb") as f: save_obj = pickle.load(f) @@ -306,14 +327,18 @@ def load_batch(cls, filename: str, save_type: Literal["pkl", "gzip"], vocab_dict save_obj = pickle.load(f) if feature_idx is None: - return {k: FeatureData.load_from_save_dict(v, vocab_dict) for k, v in save_obj.items()} - else: + return { + k: FeatureData.load_from_save_dict(v, vocab_dict) + for k, v in save_obj.items() + } + else: return FeatureData.load_from_save_dict(save_obj[feature_idx], vocab_dict) - def save(self, filename: str, save_type: Literal["pkl", "gzip"]) -> None: - '''Saves this object to a pickle file (we don't need to save the model and encoder too, just the data).''' - assert "." not in filename, "You should pass in the filename without the extension." + """Saves this object to a pickle file (we don't need to save the model and encoder too, just the data).""" + assert ( + "." not in filename + ), "You should pass in the filename without the extension." filename = filename + ".pkl" if (save_type == "pkl") else filename + ".pkl.gz" save_obj = self.return_save_dict() if save_type.startswith("pkl"): @@ -324,37 +349,35 @@ def save(self, filename: str, save_type: Literal["pkl", "gzip"]) -> None: pickle.dump(save_obj, f) return filename - def __str__(self) -> str: num_sequences = sum([len(batch) for batch in self.sequence_data.values()]) return f"FeatureData(num_sequences={num_sequences})" - def get_sequences_html(self) -> str: - sequences_html_dict = {} for group_name, sequences in self.sequence_data.items(): - - full_html = f'

{group_name}

' # style="padding-left:25px;" - + full_html = f"

{group_name}

" # style="padding-left:25px;" + for seq in sequences: html_output = generate_seq_html( self.vocab_dict, - token_ids = seq.token_ids, - feat_acts = seq.feat_acts, - contribution_to_loss = seq.contribution_to_loss, - bold_idx = self.buffer[0], # e.g. the 6th item, with index 5, if buffer=(5, 5) - is_repeat = seq.repeat, - pos_ids = seq.top5_token_ids, - neg_ids = seq.bottom5_token_ids, - pos_val = seq.top5_logit_contributions, - neg_val = seq.bottom5_logit_contributions, + token_ids=seq.token_ids, + feat_acts=seq.feat_acts, + contribution_to_loss=seq.contribution_to_loss, + bold_idx=self.buffer[ + 0 + ], # e.g. the 6th item, with index 5, if buffer=(5, 5) + is_repeat=seq.repeat, + pos_ids=seq.top5_token_ids, + neg_ids=seq.bottom5_token_ids, + pos_val=seq.top5_logit_contributions, + neg_val=seq.bottom5_logit_contributions, ) full_html += html_output - + sequences_html_dict[group_name] = full_html - + # Now, wrap all the values of this dictionary into grid-items: (top, groups of 3 for middle, bottom) html_top, html_bottom, *html_sampled = sequences_html_dict.values() sequences_html = "" @@ -367,53 +390,52 @@ def get_sequences_html(self) -> str: return sequences_html + HTML_HOVERTEXT_SCRIPT - def get_tables_html(self) -> Tuple[str, str]: - bottom10_logits, top10_logits = self.top10_logits # Get the negative and positive background values (darkest when equals max abs). Easier when in tensor form - max_value = max(np.absolute(bottom10_logits.values).max(), np.absolute(top10_logits.values).max()) + max_value = max( + np.absolute(bottom10_logits.values).max(), + np.absolute(top10_logits.values).max(), + ) neg_bg_values = np.absolute(bottom10_logits.values) / max_value pos_bg_values = np.absolute(top10_logits.values) / max_value - + # Generate the html left_tables_html, logit_tables_html = generate_tables_html( - neuron_alignment_indices = self.neuron_alignment[0].indices.tolist(), - neuron_alignment_values = self.neuron_alignment[0].values.tolist(), - neuron_alignment_l1 = self.neuron_alignment[1].tolist(), - correlated_neurons_indices = self.neurons_correlated[0].indices.tolist(), - correlated_neurons_pearson = self.neurons_correlated[0].values.tolist(), - correlated_neurons_l1 = self.neurons_correlated[1].values.tolist(), - correlated_features_indices = None, #self.b_features_correlated[0].indices.tolist(), - correlated_features_pearson = None,#self.b_features_correlated[0].values.tolist(), - correlated_features_l1 = None,#self.b_features_correlated[1].values.tolist(), - + neuron_alignment_indices=self.neuron_alignment[0].indices.tolist(), + neuron_alignment_values=self.neuron_alignment[0].values.tolist(), + neuron_alignment_l1=self.neuron_alignment[1].tolist(), + correlated_neurons_indices=self.neurons_correlated[0].indices.tolist(), + correlated_neurons_pearson=self.neurons_correlated[0].values.tolist(), + correlated_neurons_l1=self.neurons_correlated[1].values.tolist(), + correlated_features_indices=None, # self.b_features_correlated[0].indices.tolist(), + correlated_features_pearson=None, # self.b_features_correlated[0].values.tolist(), + correlated_features_l1=None, # self.b_features_correlated[1].values.tolist(), neg_str=to_str_tokens(self.vocab_dict, bottom10_logits.indices), neg_values=bottom10_logits.values.tolist(), neg_bg_values=neg_bg_values, pos_str=to_str_tokens(self.vocab_dict, top10_logits.indices), pos_values=top10_logits.values.tolist(), - pos_bg_values=pos_bg_values + pos_bg_values=pos_bg_values, ) # Return both items (we'll be wrapping them in 'grid-item' later) return 
left_tables_html, logit_tables_html - def get_histograms(self) -> Tuple[str, str]: - ''' + """ From the histogram data, returns the actual histogram HTML strings. - ''' - frequencies_histogram, logits_histogram = generate_histograms(self.frequencies_histogram_data, self.logits_histogram_data) + """ + frequencies_histogram, logits_histogram = generate_histograms( + self.frequencies_histogram_data, self.logits_histogram_data + ) return ( f"

ACTIVATIONS
DENSITY = {self.frac_nonzero:.3%}

{frequencies_histogram}
", - f"
{logits_histogram}
" + f"
{logits_histogram}
", ) - def get_all_html(self, debug: bool = False, split_scripts: bool = False) -> str: - # Get the individual HTML left_tables_html, logit_tables_html = self.get_tables_html() sequences_html = self.get_sequences_html() @@ -439,7 +461,7 @@ def get_all_html(self, debug: bool = False, split_scripts: bool = False) -> str: """ # idk why this bug is here, for representing newlines the wrong way html_string = html_string.replace("Ċ", "\n") - + if debug: display(HTML(html_string)) @@ -448,14 +470,10 @@ def get_all_html(self, debug: bool = False, split_scripts: bool = False) -> str: return scripts, html_string else: return html_string - - - - class BatchedCorrCoef: - ''' + """ This class allows me to calculate corrcoef (both Pearson and cosine sim) between two batches of vectors without needing to store them all in memory. @@ -472,7 +490,8 @@ class BatchedCorrCoef: denom = (n * x2_sum - x_sum ** 2) ** 0.5 * (n * y2_sum - y_sum ** 2) ** 0.5 ...and all these quantities (x_sum, xy_sum, etc) can be tracked on a rolling basis. - ''' + """ + def __init__(self): self.n = 0 self.x_sum = 0 @@ -481,34 +500,44 @@ def __init__(self): self.x2_sum = 0 self.y2_sum = 0 - def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]): + def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]): # noqa assert x.ndim == 2 and y.ndim == 2, "Both x and y should be 2D" - assert x.shape[-1] == y.shape[-1], "x and y should have the same size in the last dimension" - + assert ( + x.shape[-1] == y.shape[-1] + ), "x and y should have the same size in the last dimension" + self.n += x.shape[-1] self.x_sum += einops.reduce(x, "X N -> X", "sum") self.y_sum += einops.reduce(y, "Y N -> Y", "sum") self.xy_sum += einops.einsum(x, y, "X N, Y N -> X Y") - self.x2_sum += einops.reduce(x ** 2, "X N -> X", "sum") - self.y2_sum += einops.reduce(y ** 2, "Y N -> Y", "sum") + self.x2_sum += einops.reduce(x**2, "X N -> X", "sum") + self.y2_sum += einops.reduce(y**2, "Y N -> Y", "sum") - def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]: + def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]: # noqa cossim_numer = self.xy_sum cossim_denom = torch.sqrt(torch.outer(self.x2_sum, self.y2_sum)) + 1e-6 cossim = cossim_numer / cossim_denom pearson_numer = self.n * self.xy_sum - torch.outer(self.x_sum, self.y_sum) - pearson_denom = torch.sqrt(torch.outer(self.n * self.x2_sum - self.x_sum ** 2, self.n * self.y2_sum - self.y_sum ** 2)) + 1e-6 + pearson_denom = ( + torch.sqrt( + torch.outer( + self.n * self.x2_sum - self.x_sum**2, + self.n * self.y2_sum - self.y_sum**2, + ) + ) + + 1e-6 + ) pearson = pearson_numer / pearson_denom return pearson, cossim def topk(self, k: int, largest: bool = True) -> Tuple[TopK, TopK]: - '''Returns the topk corrcoefs, using Pearson (and taking this over the y-tensor)''' + """Returns the topk corrcoefs, using Pearson (and taking this over the y-tensor)""" pearson, cossim = self.corrcoef() X, Y = cossim.shape # Get pearson topk by actually taking topk - pearson_topk = TopK(pearson.topk(dim=-1, k=k, largest=largest)) # shape (X, k) + pearson_topk = TopK(pearson.topk(dim=-1, k=k, largest=largest)) # shape (X, k) # Get cossim topk by indexing into cossim with the indices of the pearson topk: cossim[X, pearson_indices[X, k]] cossim_values = eindex(cossim, pearson_topk.indices, "X [X k]") cossim_topk = TopK((cossim_values, pearson_topk.indices)) @@ -523,19 +552,17 @@ def get_feature_data( hook_point: str, hook_point_layer: int, hook_point_head_index: Optional[int], - 
tokens: Int[Tensor, "batch seq"], + tokens: Int[Tensor, "batch seq"], # noqa feature_idx: Union[int, List[int]], max_batch_size: Optional[int] = None, - left_hand_k: int = 3, buffer: Tuple[int, int] = (5, 5), n_groups: int = 10, first_group_size: int = 20, other_groups_size: int = 5, verbose: bool = False, - ) -> Dict[int, FeatureData]: - ''' + """ Gets data that will be used to create the sequences in the HTML visualisation. Args: @@ -551,14 +578,15 @@ def get_feature_data( The number of tokens on either side of the feature, for the right-hand visualisation. Returns object of class FeatureData (see that class's docstring for more info). - ''' + """ t0 = time.time() model.reset_hooks(including_permanent=True) device = model.cfg.device # Make feature_idx a list, for convenience - if isinstance(feature_idx, int): feature_idx = [feature_idx] + if isinstance(feature_idx, int): + feature_idx = [feature_idx] n_feats = len(feature_idx) # Chunk the tokens, for less memory usage @@ -574,25 +602,35 @@ def get_feature_data( # corrcoef_encoder_B = BatchedCorrCoef() # Get encoder & decoder directions - feature_act_dir = encoder.W_enc[:, feature_idx] # (d_in, feats) - feature_bias = encoder.b_enc[feature_idx] # (feats,) - feature_out_dir = encoder.W_dec[feature_idx] # (feats, d_in) - + feature_act_dir = encoder.W_enc[:, feature_idx] # (d_in, feats) + feature_bias = encoder.b_enc[feature_idx] # (feats,) + feature_out_dir = encoder.W_dec[feature_idx] # (feats, d_in) + if "resid_pre" in hook_point: - feature_mlp_out_dir = feature_out_dir # (feats, d_model) + feature_mlp_out_dir = feature_out_dir # (feats, d_model) elif "resid_post" in hook_point: - feature_mlp_out_dir = feature_out_dir @ model.W_out[hook_point_layer] # (feats, d_model) + feature_mlp_out_dir = ( + feature_out_dir @ model.W_out[hook_point_layer] + ) # (feats, d_model) elif "hook_q" in hook_point: # unembed proj onto residual stream - feature_mlp_out_dir = feature_out_dir @ model.W_Q[hook_point_layer, hook_point_head_index].T # (feats, d_model)ß - assert feature_act_dir.T.shape == feature_out_dir.shape == (len(feature_idx), encoder.cfg.d_in) + feature_mlp_out_dir = ( + feature_out_dir @ model.W_Q[hook_point_layer, hook_point_head_index].T + ) # (feats, d_model)ß + assert ( + feature_act_dir.T.shape + == feature_out_dir.shape + == (len(feature_idx), encoder.cfg.d_in) + ) t1 = time.time() # ! Define hook function to perform feature ablation - def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint): - ''' + def hook_fn_act_post( + act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint # noqa + ): # noqa + """ Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where: - f_i are the feature activations - d_i are the feature output directions @@ -601,10 +639,12 @@ def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint if we did this, then we'd have to run a different fwd pass for every feature, which is super wasteful! But later, we'll calculate the effect of feature ablation, i.e. x^j <- x^j - f_i(x^j)d_i for i = feature_idx, only on the tokens we care about (the ones which will appear in the visualisation). 
- ''' + """ # Calculate & store the feature activations (we need to store them so we can get the right-hand visualisations later) x_cent = act_post - encoder.b_dec - feat_acts_pre = einops.einsum(x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats") + feat_acts_pre = einops.einsum( + x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats" + ) feat_acts = F.relu(feat_acts_pre + feature_bias) all_feat_acts.append(feat_acts) @@ -613,7 +653,7 @@ def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint einops.rearrange(feat_acts, "batch seq feats -> feats (batch seq)"), einops.rearrange(act_post, "batch seq d_mlp -> d_mlp (batch seq)"), ) - + # Calculate encoder-B feature activations (we don't need to store them, cause it's just for the left-hand visualisations) # x_cent_B = act_post - encoder_B.b_dec # feat_acts_pre_B = einops.einsum(x_cent_B, encoder_B.W_enc, "batch seq d_mlp, d_mlp d_hidden -> batch seq d_hidden") @@ -624,11 +664,13 @@ def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint # einops.rearrange(feat_acts, "batch seq feats -> feats (batch seq)"), # einops.rearrange(feat_acts_B, "batch seq d_hidden -> d_hidden (batch seq)"), # ) - - def hook_fn_query(hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint): - ''' - - Replace act_post with projection of query onto the resid by W_k^T. + + def hook_fn_query( + hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint # noqa + ): + """ + + Replace act_post with projection of query onto the resid by W_k^T. Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where: - f_i are the feature activations - d_i are the feature output directions @@ -637,75 +679,105 @@ def hook_fn_query(hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPo if we did this, then we'd have to run a different fwd pass for every feature, which is super wasteful! But later, we'll calculate the effect of feature ablation, i.e. x^j <- x^j - f_i(x^j)d_i for i = feature_idx, only on the tokens we care about (the ones which will appear in the visualisation). - ''' + """ # Calculate & store the feature activations (we need to store them so we can get the right-hand visualisations later) hook_q = hook_q[:, :, hook_point_head_index] x_cent = hook_q - encoder.b_dec - feat_acts_pre = einops.einsum(x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats") + feat_acts_pre = einops.einsum( + x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats" + ) feat_acts = F.relu(feat_acts_pre + feature_bias) all_feat_acts.append(feat_acts) - - # project this back up to resid stream size. + + # project this back up to resid stream size. act_resid_proj = hook_q @ model.W_Q[hook_point_layer, hook_point_head_index].T # Update the CorrCoef object between feature activation & neurons corrcoef_neurons.update( einops.rearrange(feat_acts, "batch seq feats -> feats (batch seq)"), - einops.rearrange(act_resid_proj, "batch seq d_model -> d_model (batch seq)"), + einops.rearrange( + act_resid_proj, "batch seq d_model -> d_model (batch seq)" + ), ) - - def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint): - ''' + def hook_fn_resid_post( + resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint # noqa + ): + """ This hook function stores the residual activations, which we'll need later on to calculate the effect of feature ablation. 
- ''' + """ all_resid_post.append(resid_post) - # Run the model without hook (to store all the information we need, not to actually return anything) - + # ! Run the forward passes (triggering the hooks), concat all results iterator = tqdm(all_tokens, desc="Storing model activations") if "resid_pre" in hook_point: for _tokens in iterator: - model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[ - (hook_point, hook_fn_act_post), - (utils.get_act_name("resid_pre", hook_point_layer), hook_fn_resid_post) - ]) - # If we are using MLP activations, then we'd want this one. + model.run_with_hooks( + _tokens, + return_type=None, + fwd_hooks=[ + (hook_point, hook_fn_act_post), + ( + utils.get_act_name("resid_pre", hook_point_layer), + hook_fn_resid_post, + ), + ], + ) + # If we are using MLP activations, then we'd want this one. elif "resid_post" in hook_point: for _tokens in iterator: - - model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[ - (utils.get_act_name("post", hook_point_layer), hook_fn_act_post), - (utils.get_act_name("resid_post", hook_point_layer), hook_fn_resid_post) - ]) + model.run_with_hooks( + _tokens, + return_type=None, + fwd_hooks=[ + (utils.get_act_name("post", hook_point_layer), hook_fn_act_post), + ( + utils.get_act_name("resid_post", hook_point_layer), + hook_fn_resid_post, + ), + ], + ) elif "hook_q" in hook_point: iterator = tqdm(all_tokens, desc="Storing model activations") for _tokens in iterator: - model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[ - (hook_point, hook_fn_query), - (utils.get_act_name("resid_post", hook_point_layer), hook_fn_resid_post) - ]) - - + model.run_with_hooks( + _tokens, + return_type=None, + fwd_hooks=[ + (hook_point, hook_fn_query), + ( + utils.get_act_name("resid_post", hook_point_layer), + hook_fn_resid_post, + ), + ], + ) + t2 = time.time() # Stack the results, and check shapes (remember that we don't get loss for the last token) - feat_acts = torch.concatenate(all_feat_acts) # [batch seq feats] - resid_post = torch.concatenate(all_resid_post) # [batch seq d_model] + feat_acts = torch.concatenate(all_feat_acts) # [batch seq feats] + resid_post = torch.concatenate(all_resid_post) # [batch seq d_model] assert feat_acts[:, :-1].shape == tokens[:, :-1].shape + (len(feature_idx),) t3 = time.time() - - # ! Calculate all data for the left-hand column visualisations, i.e. the 3 size-3 tables # First, get the logits of this feature - logits = einops.einsum(feature_mlp_out_dir, model.W_U, "feats d_model, d_model d_vocab -> feats d_vocab") + logits = einops.einsum( + feature_mlp_out_dir, + model.W_U, + "feats d_model, d_model d_vocab -> feats d_vocab", + ) # Second, get the neurons most aligned with this feature (based on output weights) - top3_neurons_aligned = TopK(feature_out_dir.topk(dim=-1, k=left_hand_k, largest=True)) - pct_of_l1 = np.absolute(top3_neurons_aligned.values) / feature_out_dir.abs().sum(dim=-1, keepdim=True).cpu().numpy() + top3_neurons_aligned = TopK( + feature_out_dir.topk(dim=-1, k=left_hand_k, largest=True) + ) + pct_of_l1 = ( + np.absolute(top3_neurons_aligned.values) + / feature_out_dir.abs().sum(dim=-1, keepdim=True).cpu().numpy() + ) # Third, get the neurons most correlated with this feature (based on input weights) top_correlations_neurons = corrcoef_neurons.topk(k=left_hand_k, largest=True) # Lastly, get most correlated weights in B features @@ -713,8 +785,6 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo t4 = time.time() - - # ! 
Calculate all data for the right-hand visualisations, i.e. the sequences # TODO - parallelize this (it could probably be sped up by batching indices & doing all sequences at once, although those would be large tensors) # We do this in 2 steps: @@ -725,24 +795,35 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo sequence_data_list = [] - iterator = range(n_feats) if not(verbose) else tqdm(range(n_feats), desc="Getting sequence data", leave=False) + iterator = ( + range(n_feats) + if not (verbose) + else tqdm(range(n_feats), desc="Getting sequence data", leave=False) + ) for feat in iterator: + _feat_acts = feat_acts[..., feat] # [batch seq] - _feat_acts = feat_acts[..., feat] # [batch seq] - # (1) indices_dict = { - f"TOP ACTIVATIONS
MAX = {_feat_acts.max():.3f}": k_largest_indices(_feat_acts, k=first_group_size, largest=True), - f"BOTTOM ACTIVATIONS
MIN = {_feat_acts.min():.3f}": k_largest_indices(_feat_acts, k=first_group_size, largest=False), + f"TOP ACTIVATIONS
MAX = {_feat_acts.max():.3f}": k_largest_indices( + _feat_acts, k=first_group_size, largest=True + ), + f"BOTTOM ACTIVATIONS
MIN = {_feat_acts.min():.3f}": k_largest_indices( + _feat_acts, k=first_group_size, largest=False + ), } - quantiles = torch.linspace(0, _feat_acts.max(), n_groups+1) - for i in range(n_groups-1, -1, -1): - lower, upper = quantiles[i:i+2] + quantiles = torch.linspace(0, _feat_acts.max(), n_groups + 1) + for i in range(n_groups - 1, -1, -1): + lower, upper = quantiles[i : i + 2] pct = ((_feat_acts >= lower) & (_feat_acts <= upper)).float().mean() - indices = random_range_indices(_feat_acts, (lower, upper), k=other_groups_size) - indices_dict[f"INTERVAL {lower:.3f} - {upper:.3f}
CONTAINS {pct:.3%}"] = indices + indices = random_range_indices( + _feat_acts, (lower, upper), k=other_groups_size + ) + indices_dict[ + f"INTERVAL {lower:.3f} - {upper:.3f}
CONTAINS {pct:.3%}" + ] = indices # Concat all the indices together (in the next steps we do all groups at once) indices_full = torch.concat(list(indices_dict.values())) @@ -753,35 +834,59 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo # i.e. indices[..., 0] = shape (g, buf) contains the batch indices of the sequences, and indices[..., 1] = contains seq indices # (B) index into all our tensors to get the relevant data (this includes calculating the effect of ablation) # (C) construct the SequenceData objects, in the form of a SequenceDataBatch object - + # (A) # For each token index [batch, seq], we actually want [[batch, seq-buffer[0]], ..., [batch, seq], ..., [batch, seq+buffer[1]]] # We get one extra dimension at the start, because we need to see the effect on loss of the first token - buffer_tensor = torch.arange(-buffer[0] - 1, buffer[1] + 1, device=indices_full.device) - indices_full = einops.repeat(indices_full, "g two -> g buf two", buf=buffer[0] + buffer[1] + 2) - indices_full = torch.stack([indices_full[..., 0], indices_full[..., 1] + buffer_tensor], dim=-1).cpu() + buffer_tensor = torch.arange( + -buffer[0] - 1, buffer[1] + 1, device=indices_full.device + ) + indices_full = einops.repeat( + indices_full, "g two -> g buf two", buf=buffer[0] + buffer[1] + 2 + ) + indices_full = torch.stack( + [indices_full[..., 0], indices_full[..., 1] + buffer_tensor], dim=-1 + ).cpu() # (B) # Template for indexing is new_tensor[k, seq] = tensor[indices_full[k, seq, 1], indices_full[k, seq, 2]], sometimes there's an extra dim at the end tokens_group = eindex(tokens, indices_full[:, 1:], "[g buf 0] [g buf 1]") feat_acts_group = eindex(_feat_acts, indices_full, "[g buf 0] [g buf 1]") - resid_post_group = eindex(resid_post, indices_full, "[g buf 0] [g buf 1] d_model") + resid_post_group = eindex( + resid_post, indices_full, "[g buf 0] [g buf 1] d_model" + ) # From these feature activations, get the actual contribution to the final value of the residual stream - resid_post_feature_effect = einops.einsum(feat_acts_group, feature_mlp_out_dir[feat], "g buf, d_model -> g buf d_model") + resid_post_feature_effect = einops.einsum( + feat_acts_group, + feature_mlp_out_dir[feat], + "g buf, d_model -> g buf d_model", + ) # Get the resulting new logits (by subtracting this effect from resid_post, then applying layernorm & unembedding) new_resid_post = resid_post_group - resid_post_feature_effect - new_logits = (new_resid_post / new_resid_post.std(dim=-1, keepdim=True)) @ model.W_U - orig_logits = (resid_post_group / resid_post_group.std(dim=-1, keepdim=True)) @ model.W_U + new_logits = ( + new_resid_post / new_resid_post.std(dim=-1, keepdim=True) + ) @ model.W_U + orig_logits = ( + resid_post_group / resid_post_group.std(dim=-1, keepdim=True) + ) @ model.W_U # Get the top5 & bottom5 changes in logits # note - changes in logits are for hovering over predict-ING token, so it should align w/ tokens_group, hence we slice [:, 1:] - contribution_to_logprobs = orig_logits.log_softmax(dim=-1) - new_logits.log_softmax(dim=-1) - top5_contribution_to_logits = TopK(contribution_to_logprobs[:, :-1].topk(k=5, largest=True)) - bottom5_contribution_to_logits = TopK(contribution_to_logprobs[:, :-1].topk(k=5, largest=False)) + contribution_to_logprobs = orig_logits.log_softmax( + dim=-1 + ) - new_logits.log_softmax(dim=-1) + top5_contribution_to_logits = TopK( + contribution_to_logprobs[:, :-1].topk(k=5, largest=True) + ) + bottom5_contribution_to_logits = TopK( + contribution_to_logprobs[:, 
:-1].topk(k=5, largest=False) + ) # Get the change in loss (which is negative of change of logprobs for correct token) # note - changes in loss are for underlining predict-ED token, hence we slice [:, :-1] - contribution_to_loss = eindex(-contribution_to_logprobs[:, :-1], tokens_group, "g buf [g buf]") + contribution_to_loss = eindex( + -contribution_to_logprobs[:, :-1], tokens_group, "g buf [g buf]" + ) # (C) # Now that we've indexed everything, construct the batch of SequenceData objects @@ -790,14 +895,22 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo for group_name, indices in indices_dict.items(): lower, upper = g_total, g_total + len(indices) sequence_data[group_name] = SequenceDataBatch( - token_ids=tokens_group[lower: upper].tolist(), - feat_acts=feat_acts_group[lower: upper, 1:].tolist(), - contribution_to_loss=contribution_to_loss[lower: upper].tolist(), + token_ids=tokens_group[lower:upper].tolist(), + feat_acts=feat_acts_group[lower:upper, 1:].tolist(), + contribution_to_loss=contribution_to_loss[lower:upper].tolist(), repeat=False, - top5_token_ids=top5_contribution_to_logits.indices[lower: upper].tolist(), - top5_logit_contributions=top5_contribution_to_logits.values[lower: upper].tolist(), - bottom5_token_ids=bottom5_contribution_to_logits.indices[lower: upper].tolist(), - bottom5_logit_contributions=bottom5_contribution_to_logits.values[lower: upper].tolist(), + top5_token_ids=top5_contribution_to_logits.indices[ + lower:upper + ].tolist(), + top5_logit_contributions=top5_contribution_to_logits.values[ + lower:upper + ].tolist(), + bottom5_token_ids=bottom5_contribution_to_logits.indices[ + lower:upper + ].tolist(), + bottom5_logit_contributions=bottom5_contribution_to_logits.values[ + lower:upper + ].tolist(), ) g_total += len(indices) @@ -806,7 +919,6 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo t5 = time.time() - # ! Get all data for the middle column visualisations, i.e. the two histograms & the logit table nonzero_feat_acts = [] @@ -816,7 +928,6 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo logits_histogram_data = [] for feat in range(n_feats): - _logits = logits[feat] # Get data for logits histogram @@ -824,39 +935,43 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo logits_histogram_data.append(HistogramData(_logits, n_bins=40, tickmode="ints")) # Get data for logits table - top10_logits.append((TopK(_logits.topk(k=10, largest=False)), TopK(_logits.topk(k=10)))) + top10_logits.append( + (TopK(_logits.topk(k=10, largest=False)), TopK(_logits.topk(k=10))) + ) # Get data for feature activations histogram _feat_acts = feat_acts[..., feat] nonzero_feat_acts = _feat_acts[_feat_acts > 0] frac_nonzero.append(nonzero_feat_acts.numel() / _feat_acts.numel()) - frequencies_histogram_data.append(HistogramData(nonzero_feat_acts, n_bins=40, tickmode="ints")) + frequencies_histogram_data.append( + HistogramData(nonzero_feat_acts, n_bins=40, tickmode="ints") + ) t6 = time.time() - # ! 
Return the output, as a dict of FeatureData items vocab_dict = model.tokenizer.vocab - vocab_dict = {v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()} + vocab_dict = { + v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items() + } return_obj = { feat: FeatureData( - # For right-hand sequences sequence_data=sequence_data_list[i], - # For middle column (logits table, and both histograms) top10_logits=top10_logits[i], logits_histogram_data=logits_histogram_data[i], frequencies_histogram_data=frequencies_histogram_data[i], frac_nonzero=frac_nonzero[i], - # For left column, i.e. the 3 tables of size 3 neuron_alignment=(top3_neurons_aligned[i], pct_of_l1[i]), - neurons_correlated=(top_correlations_neurons[0][i], top_correlations_neurons[1][i]), - b_features_correlated=None,#(top_correlations_encoder_B[0][i], top_correlations_encoder_B[1][i]), - + neurons_correlated=( + top_correlations_neurons[0][i], + top_correlations_neurons[1][i], + ), + b_features_correlated=None, # (top_correlations_encoder_B[0][i], top_correlations_encoder_B[1][i]), # Other stuff (not containing data) vocab_dict=vocab_dict, buffer=buffer, @@ -867,41 +982,55 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo for i, feat in enumerate(feature_idx) } - # ! If verbose, try to estimate time it will take to generate data for all features, plus storage space if verbose: - n_feats_total = encoder.cfg.d_sae # Get time total_time = t5 - t0 table = Table("Task", "Time", "Pct %", title="Time taken for each task") for task, _time in zip( - ["Setup code", "Fwd passes", "Concats", "Left-hand tables", "Right-hand sequences", "Middle column"], - [t1-t0, t2-t1, t3-t2, t4-t3, t5-t4, t6-t5] + [ + "Setup code", + "Fwd passes", + "Concats", + "Left-hand tables", + "Right-hand sequences", + "Middle column", + ], + [t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4, t6 - t5], ): frac = _time / total_time table.add_row(task, f"{_time:.2f}s", f"{frac:.1%}") rprint(table) - est = ((t3 - t0) + (n_feats_total / n_feats) * (t6 - t4) / 60) + est = (t3 - t0) + (n_feats_total / n_feats) * (t6 - t4) / 60 print(f"Estimated time for all {n_feats_total} features = {est:.0f} minutes\n") # Get filesizes, for different methods of saving batch_size = 50 if n_feats >= batch_size: - print(f"Estimated filesize of all {n_feats_total} features if saved in groups of batch_size, with save type...") - save_obj = {k: v for k, v in return_obj.items() if k in feature_idx[:batch_size]} + print( + f"Estimated filesize of all {n_feats_total} features if saved in groups of batch_size, with save type..." 
+ ) + save_obj = { + k: v for k, v in return_obj.items() if k in feature_idx[:batch_size] + } filename = str(Path(__file__).parent.resolve() / "temp") for save_type in ["pkl", "gzip"]: t0 = time.time() - full_filename = FeatureData.save_batch(save_obj, filename=filename, save_type=save_type) + full_filename = FeatureData.save_batch( + save_obj, filename=filename, save_type=save_type + ) t1 = time.time() - loaded_obj = FeatureData.load_batch(filename, save_type=save_type, vocab_dict=vocab_dict) + loaded_obj = FeatureData.load_batch( + filename, save_type=save_type, vocab_dict=vocab_dict + ) t2 = time.time() filesize = os.path.getsize(full_filename) / 1e6 - print(f"{save_type:>5} = {filesize * n_feats_total / batch_size:>5.1f} MB, save time = {t1-t0:.3f}s, load time = {t2-t1:.3f}s") + print( + f"{save_type:>5} = {filesize * n_feats_total / batch_size:>5.1f} MB, save time = {t1-t0:.3f}s, load time = {t2-t1:.3f}s" + ) os.remove(full_filename) return return_obj - diff --git a/sae_analysis/visualizer/html/frequency_histogram.html b/sae_analysis/visualizer/html/frequency_histogram.html index 27baa464..2d5d17ef 100644 --- a/sae_analysis/visualizer/html/frequency_histogram.html +++ b/sae_analysis/visualizer/html/frequency_histogram.html @@ -2,4 +2,4 @@
- \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/hovertext_script.html b/sae_analysis/visualizer/html/hovertext_script.html index 70d682be..642da429 100644 --- a/sae_analysis/visualizer/html/hovertext_script.html +++ b/sae_analysis/visualizer/html/hovertext_script.html @@ -23,4 +23,4 @@ }); }); - \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/logit_table_template.html b/sae_analysis/visualizer/html/logit_table_template.html index fad88115..cba07816 100644 --- a/sae_analysis/visualizer/html/logit_table_template.html +++ b/sae_analysis/visualizer/html/logit_table_template.html @@ -38,4 +38,4 @@
 POSITIVE LOGITS
- \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/logits_histogram.html b/sae_analysis/visualizer/html/logits_histogram.html index 9e0589ea..a88ee6b3 100644 --- a/sae_analysis/visualizer/html/logits_histogram.html +++ b/sae_analysis/visualizer/html/logits_histogram.html @@ -110,4 +110,4 @@ // }); - \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/token_template.html b/sae_analysis/visualizer/html/token_template.html index 9048dea6..49b7a7c7 100644 --- a/sae_analysis/visualizer/html/token_template.html +++ b/sae_analysis/visualizer/html/token_template.html @@ -29,4 +29,4 @@ - \ No newline at end of file + diff --git a/sae_analysis/visualizer/html_fns.py b/sae_analysis/visualizer/html_fns.py index 5d3955af..3ece35ac 100644 --- a/sae_analysis/visualizer/html_fns.py +++ b/sae_analysis/visualizer/html_fns.py @@ -8,20 +8,22 @@ from sae_analysis.visualizer.utils_fns import to_str_tokens -''' +""" Key feature of these functions: the arguments should be descriptive of their role in the actual HTML visualisation. If the arguments are super arcane features of the model data, this is bad! -''' +""" ROOT_DIR = Path(__file__).parent CSS_DIR = Path(__file__).parent / "css" -CSS = "\n".join([ - (CSS_DIR / "general.css").read_text(), - (CSS_DIR / "sequences.css").read_text(), - (CSS_DIR / "tables.css").read_text(), -]) +CSS = "\n".join( + [ + (CSS_DIR / "general.css").read_text(), + (CSS_DIR / "sequences.css").read_text(), + (CSS_DIR / "tables.css").read_text(), + ] +) HTML_DIR = Path(__file__).parent / "html" HTML_TOKEN = (HTML_DIR / "token_template.html").read_text() @@ -32,35 +34,32 @@ HTML_HOVERTEXT_SCRIPT = (HTML_DIR / "hovertext_script.html").read_text() -BG_COLOR_MAP = colors.LinearSegmentedColormap.from_list("bg_color_map", ["white", "darkorange"]) +BG_COLOR_MAP = colors.LinearSegmentedColormap.from_list( + "bg_color_map", ["white", "darkorange"] +) def generate_tok_html( vocab_dict: dict, - this_token: str, underline_color: str, bg_color: str, is_bold: bool = False, - feat_act: float = 0.0, contribution_to_loss: float = 0.0, pos_ids: List[int] = [0, 0, 0, 0, 0], pos_val: List[float] = [0.0, 0.0, 0.0, 0.0, 0.0], neg_ids: List[int] = [0, 0, 0, 0, 0], neg_val: List[float] = [0.0, 0.0, 0.0, 0.0, 0.0], - - ): - ''' + """ Creates a single sequence visualisation, by reading from the `token_template.html` file. Currently, a bunch of things are randomly chosen rather than actually calculated (we're going for proof of concept here). - ''' + """ html_output = ( - HTML_TOKEN - .replace("this_token", to_str_tokens(vocab_dict, this_token)) + HTML_TOKEN.replace("this_token", to_str_tokens(vocab_dict, this_token)) .replace("feat_activation", f"{feat_act:+.3f}") .replace("feature_ablation", f"{contribution_to_loss:+.3f}") .replace("font_weight", "bold" if is_bold else "normal") @@ -70,7 +69,7 @@ def generate_tok_html( # Figure out if the activations were zero on previous token, i.e. 
no predictions were affected is_empty = len(pos_ids) + len(neg_ids) == 0 - + # Get the string tokens pos_str = [to_str_tokens(vocab_dict, i) for i in pos_ids] neg_str = [to_str_tokens(vocab_dict, i) for i in neg_ids] @@ -80,31 +79,51 @@ def generate_tok_html( neg_str.extend([""] * 5) pos_val.extend([0.0] * 5) neg_val.extend([0.0] * 5) - + # Make all the substitutions - html_output = re.sub("pos_str_(\d)", lambda m: pos_str[int(m.group(1))].replace(" ", " "), html_output) - html_output = re.sub("neg_str_(\d)", lambda m: neg_str[int(m.group(1))].replace(" ", " "), html_output) - html_output = re.sub("pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output) - html_output = re.sub("neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output) + html_output = re.sub( + "pos_str_(\d)", + lambda m: pos_str[int(m.group(1))].replace(" ", " "), + html_output, + ) + html_output = re.sub( + "neg_str_(\d)", + lambda m: neg_str[int(m.group(1))].replace(" ", " "), + html_output, + ) + html_output = re.sub( + "pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output + ) + html_output = re.sub( + "neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output + ) # If the effect on loss is nothing (because feature isn't active), replace the HTML output with smth saying this if is_empty: html_output = ( - html_output - .replace('
', '