From c0b29cc5037bfeea88888f2bb90ebd2ecdf2ed05 Mon Sep 17 00:00:00 2001 From: Can Rager Date: Thu, 28 Mar 2024 12:24:07 -0400 Subject: [PATCH] add prepend bos flag --- sae_training/activations_store.py | 31 ++++++++------ sae_training/config.py | 68 ++++++++++++++++--------------- sae_training/evals.py | 15 ++++--- 3 files changed, 63 insertions(+), 51 deletions(-) diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py index 34fdbd8b..c49c25aa 100644 --- a/sae_training/activations_store.py +++ b/sae_training/activations_store.py @@ -105,11 +105,6 @@ def get_batch_tokens(self): # TODO: Fix this so that we are limiting how many tokens we get from the same context. assert self.model.tokenizer is not None # keep pyright happy - bos_token_id_tensor = torch.tensor( - [self.model.tokenizer.bos_token_id], - device=tokens.device, - dtype=torch.long, - ) while token_len > 0 and batch_tokens.shape[0] < batch_size: # Space left in the current batch space_left = context_size - current_length @@ -129,15 +124,21 @@ def get_batch_tokens(self): token_len -= space_left # only add BOS if it's not already the first token - if tokens[0] != bos_token_id_tensor: - tokens = torch.cat( - ( - bos_token_id_tensor, - tokens, - ), - dim=0, + if self.cfg.prepend_bos: + bos_token_id_tensor = torch.tensor( + [self.model.tokenizer.bos_token_id], + device=tokens.device, + dtype=torch.long, ) - token_len += 1 + if tokens[0] != bos_token_id_tensor: + tokens = torch.cat( + ( + bos_token_id_tensor, + tokens, + ), + dim=0, + ) + token_len += 1 current_length = context_size # If a batch is full, concatenate and move to next batch @@ -170,6 +171,7 @@ def get_activations(self, batch_tokens: torch.Tensor): batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1, + prepend_bos=self.cfg.prepend_bos, )[1] activations_list = [layerwise_activations[act_name] for act_name in act_names] if self.cfg.hook_point_head_index is not None: @@ -330,6 +332,7 @@ def _get_next_dataset_tokens(self) -> torch.Tensor: s, truncate=True, move_to_device=True, + prepend_bos=self.cfg.prepend_bos, ).squeeze(0) assert ( len(tokens.shape) == 1 @@ -341,4 +344,6 @@ def _get_next_dataset_tokens(self) -> torch.Tensor: device=device, requires_grad=False, ) + if not self.cfg.prepend_bos and tokens[0] == self.model.tokenizer.bos_token_id: + tokens = tokens[1:] return tokens diff --git a/sae_training/config.py b/sae_training/config.py index ac543a4f..e9ba894d 100644 --- a/sae_training/config.py +++ b/sae_training/config.py @@ -86,6 +86,8 @@ class LanguageModelSAERunnerConfig(RunnerConfig): # Misc n_checkpoints: int = 0 checkpoint_path: str = "checkpoints" + prepend_bos: bool = True + verbose: bool = True def __post_init__(self): super().__post_init__() @@ -114,46 +116,46 @@ def __post_init__(self): ).util.generate_id() # not sure why this type is erroring self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}" - print( - f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" - ) - # Print out some useful info: - n_tokens_per_buffer = ( - self.store_batch_size * self.context_size * self.n_batches_in_buffer - ) - print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}") - n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer - print( - f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}" - ) + if self.verbose: + print( + f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + ) + # Print out some useful info: + n_tokens_per_buffer = ( + self.store_batch_size * self.context_size * self.n_batches_in_buffer + ) + print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}") + n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer + print( + f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}" + ) - total_training_steps = self.total_training_tokens // self.train_batch_size - print(f"Total training steps: {total_training_steps}") + total_training_steps = self.total_training_tokens // self.train_batch_size + print(f"Total training steps: {total_training_steps}") - total_wandb_updates = total_training_steps // self.wandb_log_frequency - print(f"Total wandb updates: {total_wandb_updates}") + total_wandb_updates = total_training_steps // self.wandb_log_frequency + print(f"Total wandb updates: {total_wandb_updates}") - # how many times will we sample dead neurons? - # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window" - n_feature_window_samples = total_training_steps // self.feature_sampling_window - print( - f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}" - ) - print( - f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}" - ) + # how many times will we sample dead neurons? + # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window" + n_feature_window_samples = total_training_steps // self.feature_sampling_window + print( + f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}" + ) + print( + f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}" + ) + print( + f"We will reset the sparsity calculation {n_feature_window_samples} times." + ) + # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size) + print( + f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}" + ) if self.use_ghost_grads: print("Using Ghost Grads.") - print( - f"We will reset the sparsity calculation {n_feature_window_samples} times." - ) - # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size) - print( - f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}" - ) - @dataclass class CacheActivationsRunnerConfig(RunnerConfig): diff --git a/sae_training/evals.py b/sae_training/evals.py index e5fb278c..956b43c1 100644 --- a/sae_training/evals.py +++ b/sae_training/evals.py @@ -70,8 +70,7 @@ def run_evals( l2_norm_out = torch.norm(sae_out, dim=-1) l2_norm_ratio = l2_norm_out / l2_norm_in - wandb.log( - { + metrics = { # l2 norms f"metrics/l2_norm{suffix}": l2_norm_out.mean().item(), f"metrics/l2_ratio{suffix}": l2_norm_ratio.mean().item(), @@ -80,9 +79,15 @@ def run_evals( f"metrics/ce_loss_without_sae{suffix}": ntp_loss, f"metrics/ce_loss_with_sae{suffix}": recons_loss, f"metrics/ce_loss_with_ablation{suffix}": zero_abl_loss, - }, - step=n_training_steps, - ) + } + + if wandb.run is not None: + wandb.log( + metrics, + step=n_training_steps + ) + + return metrics # head_index = sparse_autoencoder.cfg.hook_point_head_index