diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py index 51bc5f6c..dbaf0539 100644 --- a/sae_training/activations_store.py +++ b/sae_training/activations_store.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import os -from typing import Any, Iterator, cast +from typing import Any, Iterator, Literal, TypeVar, cast import torch from datasets import ( @@ -12,6 +14,11 @@ from torch.utils.data import DataLoader from transformer_lens import HookedTransformer +from sae_training.config import ( + CacheActivationsRunnerConfig, + LanguageModelSAERunnerConfig, +) + HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset @@ -21,18 +28,85 @@ class ActivationsStore: while training SAEs. """ + model: HookedTransformer + dataset: HfDataset + cached_activations_path: str | None + tokens_column: Literal["tokens", "input_ids", "text"] + hook_point_head_index: int | None + + @classmethod + def from_config( + cls, + model: HookedTransformer, + cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig, + dataset: HfDataset | None = None, + create_dataloader: bool = True, + ) -> "ActivationsStore": + cached_activations_path = cfg.cached_activations_path + # set cached_activations_path to None if we're not using cached activations + if ( + isinstance(cfg, LanguageModelSAERunnerConfig) + and not cfg.use_cached_activations + ): + cached_activations_path = None + return cls( + model=model, + dataset=dataset or cfg.dataset_path, + hook_point=cfg.hook_point, + hook_point_layers=listify(cfg.hook_point_layer), + hook_point_head_index=cfg.hook_point_head_index, + context_size=cfg.context_size, + d_in=cfg.d_in, + n_batches_in_buffer=cfg.n_batches_in_buffer, + total_training_tokens=cfg.total_training_tokens, + store_batch_size=cfg.store_batch_size, + train_batch_size=cfg.train_batch_size, + prepend_bos=cfg.prepend_bos, + device=cfg.device, + dtype=cfg.dtype, + cached_activations_path=cached_activations_path, + create_dataloader=create_dataloader, + ) + def __init__( self, - cfg: Any, model: HookedTransformer, - dataset: HfDataset | None = None, + dataset: HfDataset | str, + hook_point: str, + hook_point_layers: list[int], + hook_point_head_index: int | None, + context_size: int, + d_in: int, + n_batches_in_buffer: int, + total_training_tokens: int, + store_batch_size: int, + train_batch_size: int, + prepend_bos: bool, + device: str | torch.device, + dtype: torch.dtype, + cached_activations_path: str | None = None, create_dataloader: bool = True, ): - self.cfg = cfg self.model = model - self.dataset = dataset or load_dataset( - cfg.dataset_path, split="train", streaming=True + self.dataset = ( + load_dataset(dataset, split="train", streaming=True) + if isinstance(dataset, str) + else dataset ) + self.hook_point = hook_point + self.hook_point_layers = hook_point_layers + self.hook_point_head_index = hook_point_head_index + self.context_size = context_size + self.d_in = d_in + self.n_batches_in_buffer = n_batches_in_buffer + self.total_training_tokens = total_training_tokens + self.store_batch_size = store_batch_size + self.train_batch_size = train_batch_size + self.prepend_bos = prepend_bos + self.device = device + self.dtype = dtype + self.cached_activations_path = cached_activations_path + self.iterable_dataset = iter(self.dataset) # Check if dataset is tokenized @@ -40,13 +114,13 @@ def __init__( # check if it's tokenized if "tokens" in dataset_sample.keys(): - self.cfg.is_dataset_tokenized = True + self.is_dataset_tokenized = True self.tokens_column = 
"tokens" elif "input_ids" in dataset_sample.keys(): - self.cfg.is_dataset_tokenized = True + self.is_dataset_tokenized = True self.tokens_column = "input_ids" elif "text" in dataset_sample.keys(): - self.cfg.is_dataset_tokenized = False + self.is_dataset_tokenized = False self.tokens_column = "text" else: raise ValueError( @@ -54,32 +128,32 @@ def __init__( ) self.iterable_dataset = iter(self.dataset) # Reset iterator after checking - if self.cfg.use_cached_activations: # EDIT: load from multi-layer acts - assert self.cfg.cached_activations_path is not None # keep pyright happy + if cached_activations_path is not None: # EDIT: load from multi-layer acts + assert self.cached_activations_path is not None # keep pyright happy # Sanity check: does the cache directory exist? assert os.path.exists( - self.cfg.cached_activations_path - ), f"Cache directory {self.cfg.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names." + self.cached_activations_path + ), f"Cache directory {self.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names." self.next_cache_idx = 0 # which file to open next self.next_idx_within_buffer = 0 # where to start reading from in that file # Check that we have enough data on disk - first_buffer = torch.load(f"{self.cfg.cached_activations_path}/0.pt") + first_buffer = torch.load(f"{self.cached_activations_path}/0.pt") buffer_size_on_disk = first_buffer.shape[0] - n_buffers_on_disk = len(os.listdir(self.cfg.cached_activations_path)) + n_buffers_on_disk = len(os.listdir(self.cached_activations_path)) # Note: we're assuming all files have the same number of tokens # (which seems reasonable imo since that's what our script does) n_activations_on_disk = buffer_size_on_disk * n_buffers_on_disk assert ( - n_activations_on_disk > self.cfg.total_training_tokens - ), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but cfg.total_training_tokens is {self.cfg.total_training_tokens/1e6:.1f}M." + n_activations_on_disk > self.total_training_tokens + ), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but total_training_tokens is {self.total_training_tokens/1e6:.1f}M." # TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF) if create_dataloader: # fill buffer half a buffer, so we can mix it with a new buffer - self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2) + self.storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2) self.dataloader = self.get_data_loader() def get_batch_tokens(self): @@ -87,9 +161,9 @@ def get_batch_tokens(self): Streams a batch of tokens from a dataset. """ - batch_size = self.cfg.store_batch_size - context_size = self.cfg.context_size - device = self.cfg.device + batch_size = self.store_batch_size + context_size = self.context_size + device = self.device batch_tokens = torch.zeros( size=(0, context_size), device=device, dtype=torch.long, requires_grad=False @@ -124,7 +198,7 @@ def get_batch_tokens(self): token_len -= space_left # only add BOS if it's not already the first token - if self.cfg.prepend_bos: + if self.prepend_bos: bos_token_id_tensor = torch.tensor( [self.model.tokenizer.bos_token_id], device=tokens.device, @@ -160,23 +234,19 @@ def get_activations(self, batch_tokens: torch.Tensor): d_in may result from a concatenated head dimension. 
""" - layers = ( - self.cfg.hook_point_layer - if isinstance(self.cfg.hook_point_layer, list) - else [self.cfg.hook_point_layer] - ) - act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers] + layers = self.hook_point_layers + act_names = [self.hook_point.format(layer=layer) for layer in layers] hook_point_max_layer = max(layers) layerwise_activations = self.model.run_with_cache( batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1, - prepend_bos=self.cfg.prepend_bos, + prepend_bos=self.prepend_bos, )[1] activations_list = [layerwise_activations[act_name] for act_name in act_names] - if self.cfg.hook_point_head_index is not None: + if self.hook_point_head_index is not None: activations_list = [ - act[:, :, self.cfg.hook_point_head_index] for act in activations_list + act[:, :, self.hook_point_head_index] for act in activations_list ] elif activations_list[0].ndim > 3: # if we have a head dimension # flatten the head dimension @@ -190,31 +260,27 @@ def get_activations(self, batch_tokens: torch.Tensor): return stacked_activations def get_buffer(self, n_batches_in_buffer: int): - context_size = self.cfg.context_size - batch_size = self.cfg.store_batch_size - d_in = self.cfg.d_in + context_size = self.context_size + batch_size = self.store_batch_size + d_in = self.d_in total_size = batch_size * n_batches_in_buffer - num_layers = ( - len(self.cfg.hook_point_layer) - if isinstance(self.cfg.hook_point_layer, list) - else 1 - ) # Number of hook points or layers + num_layers = len(self.hook_point_layers) # Number of hook points or layers - if self.cfg.use_cached_activations: + if self.cached_activations_path is not None: # Load the activations from disk buffer_size = total_size * context_size # Initialize an empty tensor with an additional dimension for layers new_buffer = torch.zeros( (buffer_size, num_layers, d_in), - dtype=self.cfg.dtype, - device=self.cfg.device, + dtype=self.dtype, + device=self.device, ) n_tokens_filled = 0 # Assume activations for different layers are stored separately and need to be combined while n_tokens_filled < buffer_size: if not os.path.exists( - f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt" + f"{self.cached_activations_path}/{self.next_cache_idx}.pt" ): print( "\n\nWarning: Ran out of cached activation files earlier than expected." @@ -222,7 +288,7 @@ def get_buffer(self, n_batches_in_buffer: int): print( f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}." 
) - if buffer_size % self.cfg.total_training_tokens != 0: + if buffer_size % self.total_training_tokens != 0: print( "This might just be a rounding error — your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens" ) @@ -232,7 +298,7 @@ def get_buffer(self, n_batches_in_buffer: int): return new_buffer activations = torch.load( - f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt" + f"{self.cached_activations_path}/{self.next_cache_idx}.pt" ) taking_subset_of_file = False if n_tokens_filled + activations.shape[0] > buffer_size: @@ -257,8 +323,8 @@ def get_buffer(self, n_batches_in_buffer: int): # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers new_buffer = torch.zeros( (total_size, context_size, num_layers, d_in), - dtype=self.cfg.dtype, - device=self.cfg.device, + dtype=self.dtype, + device=self.device, ) for refill_batch_idx_start in refill_iterator: @@ -286,11 +352,11 @@ def get_data_loader( """ - batch_size = self.cfg.train_batch_size + batch_size = self.train_batch_size # 1. # create new buffer by mixing stored and new buffer mixing_buffer = torch.cat( - [self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer], + [self.get_buffer(self.n_batches_in_buffer // 2), self.storage_buffer], dim=0, ) @@ -325,14 +391,14 @@ def next_batch(self): return next(self.dataloader) def _get_next_dataset_tokens(self) -> torch.Tensor: - device = self.cfg.device - if not self.cfg.is_dataset_tokenized: + device = self.device + if not self.is_dataset_tokenized: s = next(self.iterable_dataset)[self.tokens_column] tokens = self.model.to_tokens( s, truncate=True, move_to_device=True, - prepend_bos=self.cfg.prepend_bos, + prepend_bos=self.prepend_bos, ).squeeze(0) assert ( len(tokens.shape) == 1 @@ -345,8 +411,17 @@ def _get_next_dataset_tokens(self) -> torch.Tensor: requires_grad=False, ) if ( - not self.cfg.prepend_bos + not self.prepend_bos and tokens[0] == self.model.tokenizer.bos_token_id # type: ignore ): tokens = tokens[1:] return tokens + + +T = TypeVar("T") + + +def listify(x: T | list[T]) -> list[T]: + if isinstance(x, list): + return x + return [x] diff --git a/sae_training/cache_activations_runner.py b/sae_training/cache_activations_runner.py index a61dec48..7ec0eac7 100644 --- a/sae_training/cache_activations_runner.py +++ b/sae_training/cache_activations_runner.py @@ -13,16 +13,21 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): model = HookedTransformer.from_pretrained(cfg.model_name) model.to(cfg.device) - activations_store = ActivationsStore(cfg, model, create_dataloader=False) + activations_store = ActivationsStore.from_config( + model, + cfg, + create_dataloader=False, + ) # if the activations directory exists and has files in it, raise an exception - if os.path.exists(activations_store.cfg.cached_activations_path): - if len(os.listdir(activations_store.cfg.cached_activations_path)) > 0: + assert activations_store.cached_activations_path is not None + if os.path.exists(activations_store.cached_activations_path): + if len(os.listdir(activations_store.cached_activations_path)) > 0: raise Exception( - f"Activations directory ({activations_store.cfg.cached_activations_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files." + f"Activations directory ({activations_store.cached_activations_path}) is not empty. Please delete it or specify a different path. 
Exiting the script to prevent accidental deletion of files." ) else: - os.makedirs(activations_store.cfg.cached_activations_path) + os.makedirs(activations_store.cached_activations_path) print(f"Started caching {cfg.total_training_tokens} activations") tokens_per_buffer = ( @@ -32,7 +37,7 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): # for i in tqdm(range(n_buffers), desc="Caching activations"): for i in range(n_buffers): buffer = activations_store.get_buffer(cfg.n_batches_in_buffer) - torch.save(buffer, f"{activations_store.cfg.cached_activations_path}/{i}.pt") + torch.save(buffer, f"{activations_store.cached_activations_path}/{i}.pt") del buffer if i % cfg.shuffle_every_n_buffers == 0 and i > 0: @@ -41,14 +46,14 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): # Do random pairwise shuffling between the last shuffle_every_n_buffers buffers for _ in range(cfg.n_shuffles_with_last_section): shuffle_activations_pairwise( - activations_store.cfg.cached_activations_path, + activations_store.cached_activations_path, buffer_idx_range=(i - cfg.shuffle_every_n_buffers, i), ) # Do more random pairwise shuffling between all the buffers for _ in range(cfg.n_shuffles_in_entire_dir): shuffle_activations_pairwise( - activations_store.cfg.cached_activations_path, + activations_store.cached_activations_path, buffer_idx_range=(0, i), ) @@ -56,6 +61,6 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): if n_buffers > 1: for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"): shuffle_activations_pairwise( - activations_store.cfg.cached_activations_path, + activations_store.cached_activations_path, buffer_idx_range=(0, n_buffers), ) diff --git a/sae_training/config.py b/sae_training/config.py index 87aaa523..88743ee7 100644 --- a/sae_training/config.py +++ b/sae_training/config.py @@ -1,4 +1,3 @@ -from abc import ABC from dataclasses import dataclass from typing import Any, Optional, cast @@ -8,15 +7,15 @@ @dataclass -class RunnerConfig(ABC): +class LanguageModelSAERunnerConfig: """ - The config that's shared across all runners. + Configuration for training a sparse autoencoder on a language model. """ # Data Generating Function (Model + Training Distibuion) model_name: str = "gelu-2l" hook_point: str = "blocks.{layer}.hook_mlp_out" - hook_point_layer: int = 0 + hook_point_layer: int | list[int] = 0 hook_point_head_index: Optional[int] = None dataset_path: str = "NeelNanda/c4-tokenized-2b" is_dataset_tokenized: bool = True @@ -33,47 +32,39 @@ class RunnerConfig(ABC): n_batches_in_buffer: int = 20 total_training_tokens: int = 2_000_000 store_batch_size: int = 32 + train_batch_size: int = 4096 # Misc device: str | torch.device = "cpu" seed: int = 42 dtype: torch.dtype = torch.float32 - - def __post_init__(self): - # Autofill cached_activations_path unless the user overrode it - if self.cached_activations_path is None: - self.cached_activations_path = f"activations/{self.dataset_path.replace('/', '_')}/{self.model_name.replace('/', '_')}/{self.hook_point}" - if self.hook_point_head_index is not None: - self.cached_activations_path += f"_{self.hook_point_head_index}" - - -@dataclass -class LanguageModelSAERunnerConfig(RunnerConfig): - """ - Configuration for training a sparse autoencoder on a language model. 
- """ + prepend_bos: bool = True # SAE Parameters b_dec_init_method: str = "geometric_median" - expansion_factor: int = 4 + expansion_factor: int | list[int] = 4 from_pretrained_path: Optional[str] = None d_sae: Optional[int] = None # Training Parameters - l1_coefficient: float = 1e-3 - lp_norm: float = 1 - lr: float = 3e-4 - lr_end: float | None = None # only used for cosine annealing, default is lr / 10 - lr_scheduler_name: str = ( + l1_coefficient: float | list[float] = 1e-3 + lp_norm: float | list[float] = 1 + lr: float | list[float] = 3e-4 + lr_scheduler_name: str | list[str] = ( "constant" # constant, cosineannealing, cosineannealingwarmrestarts ) - lr_warm_up_steps: int = 500 - lr_decay_steps: int = 0 + lr_warm_up_steps: int | list[int] = 500 + lr_end: float | list[float] | None = ( + None # only used for cosine annealing, default is lr / 10 + ) + lr_decay_steps: int | list[int] = 0 + n_restart_cycles: int | list[int] = 1 # used only for cosineannealingwarmrestarts train_batch_size: int = 4096 - n_restart_cycles: int = 1 # only used for cosineannealingwarmrestarts # Resampling protocol args - use_ghost_grads: bool = False # want to change this to true on some timeline. + use_ghost_grads: bool | list[bool] = ( + False # want to change this to true on some timeline. + ) feature_sampling_window: int = 2000 dead_feature_window: int = 1000 # unless this window is larger feature sampling, @@ -89,11 +80,17 @@ class LanguageModelSAERunnerConfig(RunnerConfig): # Misc n_checkpoints: int = 0 checkpoint_path: str = "checkpoints" - prepend_bos: bool = True verbose: bool = True def __post_init__(self): - super().__post_init__() + if self.use_cached_activations and self.cached_activations_path is None: + self.cached_activations_path = _default_cached_activations_path( + self.dataset_path, + self.model_name, + self.hook_point, + self.hook_point_head_index, + ) + if not isinstance(self.expansion_factor, list): self.d_sae = self.d_in * self.expansion_factor self.tokens_per_buffer = ( @@ -112,10 +109,13 @@ def __post_init__(self): "Warning: We are initializing b_dec to zeros. This is probably not what you want." ) - self.device = torch.device(self.device) + self.device: str | torch.device = torch.device(self.device) if self.lr_end is None: - self.lr_end = self.lr / 10 + if isinstance(self.lr, list): + self.lr_end = [lr / 10 for lr in self.lr] + else: + self.lr_end = self.lr / 10 unique_id = cast( Any, wandb @@ -161,16 +161,43 @@ def __post_init__(self): f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}" ) - if self.use_ghost_grads: + if not isinstance(self.use_ghost_grads, list) and self.use_ghost_grads: print("Using Ghost Grads.") @dataclass -class CacheActivationsRunnerConfig(RunnerConfig): +class CacheActivationsRunnerConfig: """ Configuration for caching activations of an LLM. 
""" + # Data Generating Function (Model + Training Distibuion) + model_name: str = "gelu-2l" + hook_point: str = "blocks.{layer}.hook_mlp_out" + hook_point_layer: int | list[int] = 0 + hook_point_head_index: Optional[int] = None + dataset_path: str = "NeelNanda/c4-tokenized-2b" + is_dataset_tokenized: bool = True + context_size: int = 128 + cached_activations_path: Optional[str] = ( + None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}" + ) + + # SAE Parameters + d_in: int = 512 + + # Activation Store Parameters + n_batches_in_buffer: int = 20 + total_training_tokens: int = 2_000_000 + store_batch_size: int = 32 + train_batch_size: int = 4096 + + # Misc + device: str | torch.device = "cpu" + seed: int = 42 + dtype: torch.dtype = torch.float32 + prepend_bos: bool = True + # Activation caching stuff shuffle_every_n_buffers: int = 10 n_shuffles_with_last_section: int = 10 @@ -178,9 +205,23 @@ class CacheActivationsRunnerConfig(RunnerConfig): n_shuffles_final: int = 100 def __post_init__(self): - super().__post_init__() - if self.use_cached_activations: - # this is a dummy property in this context; only here to avoid class compatibility headaches - raise ValueError( - "use_cached_activations should be False when running cache_activations_runner" + # Autofill cached_activations_path unless the user overrode it + if self.cached_activations_path is None: + self.cached_activations_path = _default_cached_activations_path( + self.dataset_path, + self.model_name, + self.hook_point, + self.hook_point_head_index, ) + + +def _default_cached_activations_path( + dataset_path: str, + model_name: str, + hook_point: str, + hook_point_head_index: int | None, +) -> str: + path = f"activations/{dataset_path.replace('/', '_')}/{model_name.replace('/', '_')}/{hook_point}" + if hook_point_head_index is not None: + path += f"_{hook_point_head_index}" + return path diff --git a/sae_training/evals.py b/sae_training/evals.py index 932b6a9b..179333d7 100644 --- a/sae_training/evals.py +++ b/sae_training/evals.py @@ -21,7 +21,7 @@ def run_evals( suffix: str = "", ) -> Mapping[str, Any]: hook_point = sparse_autoencoder.cfg.hook_point - hook_point_layer = sparse_autoencoder.cfg.hook_point_layer + hook_point_layer = sparse_autoencoder.hook_point_layer hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index ### Evals diff --git a/sae_training/sae_group.py b/sae_training/sae_group.py index 568367f2..bff0950b 100644 --- a/sae_training/sae_group.py +++ b/sae_training/sae_group.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import gzip import os @@ -7,6 +9,7 @@ import torch +from sae_training.config import LanguageModelSAERunnerConfig from sae_training.sparse_autoencoder import SparseAutoencoder @@ -14,12 +17,12 @@ class SAEGroup: autoencoders: list[SparseAutoencoder] - def __init__(self, cfg: Any): + def __init__(self, cfg: LanguageModelSAERunnerConfig): self.cfg = cfg self.autoencoders = [] # This will store tuples of (instance, hyperparameters) self._init_autoencoders(cfg) - def _init_autoencoders(self, cfg: Any): + def _init_autoencoders(self, cfg: LanguageModelSAERunnerConfig): # Dynamically get all combinations of hyperparameters from cfg # Extract all hyperparameter lists from cfg hyperparameters = {k: v for k, v in vars(cfg).items() if isinstance(v, list)} @@ -53,8 +56,9 @@ def to(self, device: torch.device | str): for ae in self.autoencoders: ae.to(device) + # old pickled SAEs load as a dict @classmethod - def load_from_pretrained(cls, 
path: str): + def load_from_pretrained(cls, path: str) -> "SAEGroup" | dict[str, Any]: """ Load function for the model. Loads the model's state_dict and the config used to train it. This method can be called directly on the class, without needing an instance. @@ -69,7 +73,10 @@ def load_from_pretrained(cls, path: str): try: if torch.backends.mps.is_available(): group = torch.load(path, map_location="mps") - group["cfg"].device = "mps" + if isinstance(group, dict): + group["cfg"].device = "mps" + else: + group.cfg.device = "mps" else: group = torch.load(path) except Exception as e: diff --git a/sae_training/sparse_autoencoder.py b/sae_training/sparse_autoencoder.py index 3a15d60c..425aa084 100644 --- a/sae_training/sparse_autoencoder.py +++ b/sae_training/sparse_autoencoder.py @@ -27,6 +27,14 @@ class ForwardOutput(NamedTuple): class SparseAutoencoder(HookedRootModule): """ """ + l1_coefficient: float + lp_norm: float + d_sae: int + use_ghost_grads: bool + hook_point_layer: int + dtype: torch.dtype + device: str | torch.device + def __init__( self, cfg: LanguageModelSAERunnerConfig, @@ -39,11 +47,25 @@ def __init__( f"d_in must be an int but was {self.d_in=}; {type(self.d_in)=}" ) assert cfg.d_sae is not None # keep pyright happy + # lists are valid only for SAEGroup cfg, not SAE cfg vals + assert not isinstance(cfg.l1_coefficient, list) + assert not isinstance(cfg.lp_norm, list) + assert not isinstance(cfg.lr, list) + assert not isinstance(cfg.lr_scheduler_name, list) + assert not isinstance(cfg.lr_warm_up_steps, list) + assert not isinstance(cfg.use_ghost_grads, list) + assert not isinstance(cfg.hook_point_layer, list) + assert ( + "{layer}" not in cfg.hook_point + ), "{layer} must be replaced with the actual layer number in SAE cfg" + self.d_sae = cfg.d_sae self.l1_coefficient = cfg.l1_coefficient self.lp_norm = cfg.lp_norm self.dtype = cfg.dtype self.device = cfg.device + self.use_ghost_grads = cfg.use_ghost_grads + self.hook_point_layer = cfg.hook_point_layer # NOTE: if using resampling neurons method, you must ensure that we initialise the weights in the order W_enc, b_enc, W_dec, b_dec self.W_enc = nn.Parameter( @@ -107,7 +129,7 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None) ghost_grad_loss = torch.tensor(0.0, dtype=self.dtype, device=self.device) # gate on config and training so evals is not slowed down. 
if ( - self.cfg.use_ghost_grads + self.use_ghost_grads and self.training and dead_neuron_mask is not None and dead_neuron_mask.sum() > 0 diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py index 00e5113f..fe3916c6 100644 --- a/sae_training/train_sae_on_language_model.py +++ b/sae_training/train_sae_on_language_model.py @@ -205,6 +205,23 @@ def _wandb_log_suffix(cfg: Any, hyperparams: Any): def _build_train_context( sae: SparseAutoencoder, total_training_steps: int ) -> SAETrainContext: + assert not isinstance(sae.cfg.lr, list), "lr must not be a list for a single SAE" + assert not isinstance( + sae.cfg.lr_end, list + ), "lr_end must not be a list for a single SAE" + assert not isinstance( + sae.cfg.lr_scheduler_name, list + ), "lr_scheduler_name must not be a list for a single SAE" + assert not isinstance( + sae.cfg.lr_warm_up_steps, list + ), "lr_warm_up_steps must not be a list for a single SAE" + assert not isinstance( + sae.cfg.lr_decay_steps, list + ), "lr_decay_steps must not be a list for a single SAE" + assert not isinstance( + sae.cfg.n_restart_cycles, list + ), "n_restart_cycles must not be a list for a single SAE" + act_freq_scores = torch.zeros( cast(int, sae.cfg.d_sae), device=sae.cfg.device, @@ -246,7 +263,7 @@ def _init_sae_group_b_decs( geometric_medians = {} for sae in sae_group: hyperparams = sae.cfg - sae_layer_id = all_layers.index(hyperparams.hook_point_layer) + sae_layer_id = all_layers.index(sae.hook_point_layer) if hyperparams.b_dec_init_method == "geometric_median": layer_acts = activation_store.storage_buffer.detach()[:, sae_layer_id, :] # get geometric median of the activations if we're using those. @@ -288,8 +305,7 @@ def _train_step( wandb_suffix: str, ) -> TrainStepOutput: assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy - hyperparams = sparse_autoencoder.cfg - layer_id = all_layers.index(hyperparams.hook_point_layer) + layer_id = all_layers.index(sparse_autoencoder.hook_point_layer) sae_in = layer_acts[:, layer_id, :] sparse_autoencoder.train() diff --git a/sae_training/utils.py b/sae_training/utils.py index f407b92a..a96c4de0 100644 --- a/sae_training/utils.py +++ b/sae_training/utils.py @@ -4,6 +4,7 @@ from transformer_lens import HookedTransformer from sae_training.activations_store import ActivationsStore +from sae_training.config import LanguageModelSAERunnerConfig from sae_training.sae_group import SAEGroup from sae_training.sparse_autoencoder import SparseAutoencoder @@ -78,7 +79,7 @@ def get_model(self, model_name: str): return model - def initialize_sparse_autoencoder(self, cfg: Any): + def initialize_sparse_autoencoder(self, cfg: LanguageModelSAERunnerConfig): """ Initializes a sparse autoencoder group, which contains multiple sparse autoencoders """ @@ -87,14 +88,16 @@ def initialize_sparse_autoencoder(self, cfg: Any): return sparse_autoencoder - def get_activations_loader(self, cfg: Any, model: HookedTransformer): + def get_activations_loader( + self, cfg: LanguageModelSAERunnerConfig, model: HookedTransformer + ): """ Loads a DataLoaderBuffer for the activations of a language model. 
""" - activations_loader = ActivationsStore( - cfg, + activations_loader = ActivationsStore.from_config( model, + cfg, ) return activations_loader diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 200fe225..a4f17541 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -23,7 +23,6 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: use_cached_activations=False, d_in=64, expansion_factor=2, - d_sae=64 * 2, l1_coefficient=2e-3, lp_norm=1, lr=2e-4, diff --git a/tests/unit/test_activations_store.py b/tests/unit/test_activations_store.py index b746c37f..af06661f 100644 --- a/tests/unit/test_activations_store.py +++ b/tests/unit/test_activations_store.py @@ -1,6 +1,5 @@ from collections.abc import Iterable from math import ceil -from types import SimpleNamespace import pytest import torch @@ -8,6 +7,7 @@ from transformer_lens import HookedTransformer from sae_training.activations_store import ActivationsStore +from sae_training.config import LanguageModelSAERunnerConfig from tests.unit.helpers import build_sae_cfg @@ -74,67 +74,27 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]: "gpt2", ], ) -def cfg(request: pytest.FixtureRequest) -> SimpleNamespace: +def cfg(request: pytest.FixtureRequest) -> LanguageModelSAERunnerConfig: # This function will be called with each parameter set params = request.param - mock_config = SimpleNamespace() - mock_config.model_name = params["model_name"] - mock_config.dataset_path = params["dataset_path"] - mock_config.is_dataset_tokenized = params["tokenized"] - mock_config.hook_point = params["hook_point"] - mock_config.hook_point_layer = params["hook_point_layer"] - mock_config.d_in = params["d_in"] - mock_config.expansion_factor = 2 - mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor - mock_config.l1_coefficient = 2e-3 - mock_config.lr = 2e-4 - mock_config.train_batch_size = 32 - mock_config.context_size = 16 - mock_config.use_cached_activations = False - mock_config.hook_point_head_index = None - - mock_config.feature_sampling_method = None - mock_config.feature_sampling_window = 50 - mock_config.feature_reinit_scale = 0.1 - mock_config.dead_feature_threshold = 1e-7 - - mock_config.n_batches_in_buffer = 4 - mock_config.total_training_tokens = 1_000_000 - mock_config.store_batch_size = 32 - - mock_config.log_to_wandb = False - mock_config.wandb_project = "test_project" - mock_config.wandb_entity = "test_entity" - mock_config.wandb_log_frequency = 10 - mock_config.device = torch.device("cpu") - mock_config.seed = 24 - mock_config.checkpoint_path = "test/checkpoints" - mock_config.dtype = torch.float32 - mock_config.prepend_bos = params["prepend_bos"] - return mock_config + return build_sae_cfg(**params) @pytest.fixture -def model(cfg: SimpleNamespace): +def model(cfg: LanguageModelSAERunnerConfig): return HookedTransformer.from_pretrained(cfg.model_name, device="cpu") @pytest.fixture -def activation_store(cfg: SimpleNamespace, model: HookedTransformer): - return ActivationsStore(cfg, model) +def activation_store(cfg: LanguageModelSAERunnerConfig, model: HookedTransformer): + return ActivationsStore.from_config(model, cfg) -@pytest.fixture -def activation_store_head_hook( - cfg_head_hook: SimpleNamespace, model: HookedTransformer +def test_activations_store__init__( + cfg: LanguageModelSAERunnerConfig, model: HookedTransformer ): - return ActivationsStore(cfg_head_hook, model) - - -def test_activations_store__init__(cfg: SimpleNamespace, model: HookedTransformer): - store = 
ActivationsStore(cfg, model) + store = ActivationsStore.from_config(model, cfg) - assert store.cfg == cfg assert store.model == model assert isinstance(store.dataset, IterableDataset) @@ -158,10 +118,10 @@ def test_activations_store__get_batch_tokens(activation_store: ActivationsStore) assert isinstance(batch, torch.Tensor) assert batch.shape == ( - activation_store.cfg.store_batch_size, - activation_store.cfg.context_size, + activation_store.store_batch_size, + activation_store.context_size, ) - assert batch.device == activation_store.cfg.device + assert batch.device == activation_store.device def test_activations_score_get_next_batch( @@ -170,8 +130,8 @@ def test_activations_score_get_next_batch( batch = activation_store.get_batch_tokens() assert batch.shape == ( - activation_store.cfg.store_batch_size, - activation_store.cfg.context_size, + activation_store.store_batch_size, + activation_store.context_size, ) # if model.tokenizer.bos_token_id is not None: @@ -184,10 +144,14 @@ def test_activations_store__get_activations(activation_store: ActivationsStore): batch = activation_store.get_batch_tokens() activations = activation_store.get_activations(batch) - cfg = activation_store.cfg assert isinstance(activations, torch.Tensor) - assert activations.shape == (cfg.store_batch_size, cfg.context_size, 1, cfg.d_in) - assert activations.device == cfg.device + assert activations.shape == ( + activation_store.store_batch_size, + activation_store.context_size, + 1, + activation_store.d_in, + ) + assert activations.device == activation_store.device def test_activations_store__get_activations_head_hook(ts_model: HookedTransformer): @@ -197,26 +161,33 @@ def test_activations_store__get_activations_head_hook(ts_model: HookedTransforme hook_point_layer=1, d_in=4, ) - activation_store_head_hook = ActivationsStore(cfg, ts_model) + activation_store_head_hook = ActivationsStore.from_config(ts_model, cfg) batch = activation_store_head_hook.get_batch_tokens() activations = activation_store_head_hook.get_activations(batch) - cfg = activation_store_head_hook.cfg assert isinstance(activations, torch.Tensor) - assert activations.shape == (cfg.store_batch_size, cfg.context_size, 1, cfg.d_in) - assert activations.device == cfg.device + assert activations.shape == ( + activation_store_head_hook.store_batch_size, + activation_store_head_hook.context_size, + 1, + activation_store_head_hook.d_in, + ) + assert activations.device == activation_store_head_hook.device def test_activations_store__get_buffer(activation_store: ActivationsStore): n_batches_in_buffer = 3 buffer = activation_store.get_buffer(n_batches_in_buffer) - cfg = activation_store.cfg assert isinstance(buffer, torch.Tensor) - buffer_size_expected = cfg.store_batch_size * cfg.context_size * n_batches_in_buffer + buffer_size_expected = ( + activation_store.store_batch_size + * activation_store.context_size + * n_batches_in_buffer + ) - assert buffer.shape == (buffer_size_expected, 1, cfg.d_in) - assert buffer.device == cfg.device + assert buffer.shape == (buffer_size_expected, 1, activation_store.d_in) + assert buffer.device == activation_store.device # 12 is divisible by the length of "hello world", 11 and 13 are not @@ -236,8 +207,8 @@ def test_activations_store__get_batch_tokens__fills_the_context_separated_by_bos context_size=context_size, ) - activation_store = ActivationsStore( - cfg, ts_model, dataset=dataset, create_dataloader=False + activation_store = ActivationsStore.from_config( + ts_model, cfg, dataset=dataset, create_dataloader=False ) 
encoded_text = tokenize_with_bos(ts_model, "hello world") tokens = activation_store.get_batch_tokens() @@ -264,8 +235,8 @@ def test_activations_store__get_next_dataset_tokens__tokenizes_each_example_in_o {"text": "hello world3"}, ] ) - activation_store = ActivationsStore( - cfg, ts_model, dataset=dataset, create_dataloader=False + activation_store = ActivationsStore.from_config( + ts_model, cfg, dataset=dataset, create_dataloader=False ) assert activation_store._get_next_dataset_tokens().tolist() == tokenize_with_bos( diff --git a/tests/unit/test_sae_group.py b/tests/unit/test_sae_group.py new file mode 100644 index 00000000..fc1734af --- /dev/null +++ b/tests/unit/test_sae_group.py @@ -0,0 +1,44 @@ +from sae_training.sae_group import SAEGroup +from tests.unit.helpers import build_sae_cfg + + +def test_SAEGroup_initializes_all_permutations_of_list_params(): + cfg = build_sae_cfg( + d_in=5, + lr=[0.01, 0.001], + expansion_factor=[2, 4], + ) + sae_group = SAEGroup(cfg) + assert len(sae_group) == 4 + lr_sae_combos = [(ae.cfg.lr, ae.cfg.d_sae) for ae in sae_group] + assert (0.01, 10) in lr_sae_combos + assert (0.01, 20) in lr_sae_combos + assert (0.001, 10) in lr_sae_combos + assert (0.001, 20) in lr_sae_combos + + +def test_SAEGroup_replaces_layer_with_actual_layer(): + cfg = build_sae_cfg( + hook_point="blocks.{layer}.attn.hook_q", + hook_point_layer=5, + ) + sae_group = SAEGroup(cfg) + assert len(sae_group) == 1 + assert sae_group.autoencoders[0].cfg.hook_point == "blocks.5.attn.hook_q" + + +def test_SAEGroup_train_and_eval(): + cfg = build_sae_cfg( + lr=[0.01, 0.001], + expansion_factor=[2, 4], + ) + sae_group = SAEGroup(cfg) + sae_group.train() + for sae in sae_group: + assert sae.training is True + sae_group.eval() + for sae in sae_group: + assert sae.training is False + sae_group.train() + for sae in sae_group: + assert sae.training is True diff --git a/tests/unit/test_sparse_autoencoder.py b/tests/unit/test_sparse_autoencoder.py index 9cf059f5..3d459ea9 100644 --- a/tests/unit/test_sparse_autoencoder.py +++ b/tests/unit/test_sparse_autoencoder.py @@ -228,7 +228,7 @@ def test_SparseAutoencoder_remove_gradient_parallel_to_decoder_directions() -> N grad_delta = orig_grad - sae.W_dec.grad assert torch.nn.functional.cosine_similarity( sae.W_dec.detach(), grad_delta, dim=1 - ).abs() == pytest.approx(1.0, abs=1e-5) + ).abs() == pytest.approx(1.0, abs=1e-3) def test_SparseAutoencoder_get_name_returns_correct_name_from_cfg_vals() -> None: diff --git a/tests/unit/test_train_sae_on_language_model.py b/tests/unit/test_train_sae_on_language_model.py index 3f5b4e58..5c30863a 100644 --- a/tests/unit/test_train_sae_on_language_model.py +++ b/tests/unit/test_train_sae_on_language_model.py @@ -34,6 +34,7 @@ def build_train_ctx( Factory helper to build a default SAETrainContext object. """ assert sae.cfg.d_sae is not None + assert not isinstance(sae.cfg.lr, list) optimizer = torch.optim.Adam(sae.parameters(), lr=sae.cfg.lr) return SAETrainContext( act_freq_scores=( @@ -323,7 +324,7 @@ def test_train_sae_group_on_language_model__runs_and_outputs_look_reasonable( ) # just a tiny datast which will run quickly dataset = Dataset.from_list([{"text": "hello world"}] * 1000) - activation_store = ActivationsStore(cfg, model=ts_model, dataset=dataset) + activation_store = ActivationsStore.from_config(ts_model, cfg, dataset=dataset) sae_group = SAEGroup(cfg) res = train_sae_group_on_language_model( model=ts_model,
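
Illustrative usage sketch (not part of the diff above, and untested): the refactored ActivationsStore.from_config replaces the old ActivationsStore(cfg, model) constructor, and list-valued hyperparameters on LanguageModelSAERunnerConfig fan out into one SparseAutoencoder per combination inside SAEGroup, as the new unit tests exercise. Concrete values below are placeholders.

# Minimal sketch of the refactored API, assuming the package layout in this diff.
from transformer_lens import HookedTransformer

from sae_training.activations_store import ActivationsStore
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.sae_group import SAEGroup

cfg = LanguageModelSAERunnerConfig(
    model_name="gelu-2l",
    hook_point="blocks.{layer}.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    expansion_factor=[2, 4],  # list-valued: sweeps over SAE widths
    lr=[3e-4, 1e-4],          # list-valued: sweeps over learning rates
)

model = HookedTransformer.from_pretrained(cfg.model_name)
model.to(cfg.device)

# The store now receives explicit fields via from_config instead of holding a cfg object.
store = ActivationsStore.from_config(model, cfg)

# 2 expansion factors x 2 learning rates -> 4 autoencoders, each carrying
# scalar (non-list) hyperparameter values in its own cfg.
sae_group = SAEGroup(cfg)
assert len(sae_group) == 4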