From afcc239de1aa91961a14e5734644d2ffaf4b764a Mon Sep 17 00:00:00 2001
From: Lucy Farnik
Date: Sat, 2 Dec 2023 12:48:35 +0000
Subject: [PATCH] Added support for non-tokenized datasets

---
 sae_training/activations_buffer.py    | 25 +++++-----
 sae_training/lm_runner.py             |  3 +-
 .../test_language_model_sae_runner.py | 48 +++++++++++++++++++
 3 files changed, 64 insertions(+), 12 deletions(-)

diff --git a/sae_training/activations_buffer.py b/sae_training/activations_buffer.py
index c8016051..28169941 100644
--- a/sae_training/activations_buffer.py
+++ b/sae_training/activations_buffer.py
@@ -12,11 +12,14 @@ class DataLoaderBuffer:
 
     def __init__(
-        self, cfg, model: HookedTransformer, data_path="NeelNanda/c4-code-tokenized-2b"
+        self, cfg, model: HookedTransformer,
+        data_path="NeelNanda/c4-code-tokenized-2b",
+        is_dataset_tokenized=True,
     ):
         self.cfg = cfg
         self.model = model
         self.data_path = data_path
+        self.is_dataset_tokenized = is_dataset_tokenized
 
         self.dataset = load_dataset(data_path, split="train", streaming=True)
         self.iterable_dataset = iter(self.dataset)
         self.buffer = torch.zeros(0, self.cfg.d_in, device=self.cfg.device)
@@ -37,16 +40,16 @@ def get_batch_tokens(self):
 
         # pbar = tqdm(total=batch_size, desc="Filling batches")
         while batch_tokens.shape[0] < batch_size:
-            # if not pretokenized:
-            #     s = next(dataset)["text"]
-            #     tokens = model.to_tokens(s, truncate=False, move_to_device=True).squeeze(0)
-            #     assert len(tokens.shape) == 1, f"tokens.shape should be 1D but was {tokens.shape}"
-            # else:
-            tokens = torch.tensor(
-                next(self.iterable_dataset)["tokens"],
-                dtype=torch.long,
-                device=device,
-            )
+            if not self.is_dataset_tokenized:
+                s = next(self.iterable_dataset)["text"]
+                tokens = self.model.to_tokens(s, truncate=False, move_to_device=True).squeeze(0)
+                assert len(tokens.shape) == 1, f"tokens.shape should be 1D but was {tokens.shape}"
+            else:
+                tokens = torch.tensor(
+                    next(self.iterable_dataset)["tokens"],
+                    dtype=torch.long,
+                    device=device,
+                )
             token_len = tokens.shape[0]
 
             while token_len > 0:
diff --git a/sae_training/lm_runner.py b/sae_training/lm_runner.py
index a1f8ec05..4696bf80 100644
--- a/sae_training/lm_runner.py
+++ b/sae_training/lm_runner.py
@@ -23,6 +23,7 @@ class LanguageModelSAERunnerConfig:
     hook_point: str = "blocks.0.hook_mlp_out"
     hook_point_layer: int = 0
     dataset_path: str = "NeelNanda/c4-tokenized-2b"
+    is_dataset_tokenized: bool = True
 
     # SAE Parameters
     d_in: int = 512
@@ -69,7 +70,7 @@ def language_model_sae_runner(cfg):
 
     # initialize dataset
     activations_buffer = DataLoaderBuffer(
-        cfg, model, data_path=cfg.dataset_path
+        cfg, model, data_path=cfg.dataset_path, is_dataset_tokenized=cfg.is_dataset_tokenized,
    )
 
     # initialize the SAE
diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py
index d1a61ce2..3d7e6e55 100644
--- a/tests/benchmark/test_language_model_sae_runner.py
+++ b/tests/benchmark/test_language_model_sae_runner.py
@@ -17,6 +17,7 @@ def test_language_model_sae_runner_mlp_out():
         hook_point_layer = 0,
         d_in = 512,
         dataset_path = "NeelNanda/c4-tokenized-2b",
+        is_dataset_tokenized=True,
 
         # SAE Parameters
         expansion_factor = 64, # determines the dimension of the SAE.
@@ -66,6 +67,7 @@ def test_language_model_sae_runner_resid_pre():
         hook_point_layer = 0,
         d_in = 512,
         dataset_path = "NeelNanda/c4-tokenized-2b",
+        is_dataset_tokenized=True,
 
         # SAE Parameters
         expansion_factor = 32, # determines the dimension of the SAE.
@@ -98,3 +100,49 @@ def test_language_model_sae_runner_resid_pre():
 
     assert trained_sae is not None
 
+def test_language_model_sae_runner_not_tokenized():
+
+    cfg = LanguageModelSAERunnerConfig(
+
+        # Data Generating Function (Model + Training Distribution)
+        model_name = "gelu-2l",
+        hook_point = "blocks.1.hook_mlp_out",
+        hook_point_layer = 0,
+        d_in = 512,
+        dataset_path = "roneneldan/TinyStories",
+        is_dataset_tokenized=False,
+
+        # SAE Parameters
+        expansion_factor = 64, # determines the dimension of the SAE.
+
+        # Training Parameters
+        lr = 1e-4,
+        l1_coefficient = 1e-4,
+        train_batch_size = 4096,
+        context_size = 128,
+
+        # Activation Store Parameters
+        n_batches_in_buffer = 8,
+        total_training_tokens = 25_000_00 * 60,
+        store_batch_size = 32,
+
+        # Resampling protocol
+        feature_sampling_window = 1000,
+        feature_reinit_scale = 0.2,
+        dead_feature_threshold = 1e-8,
+
+        # WANDB
+        log_to_wandb = True,
+        wandb_project = "mats_sae_training_language_models",
+        wandb_entity = None,
+
+        # Misc
+        device = "mps",
+        seed = 42,
+        checkpoint_path = "checkpoints",
+        dtype = torch.float32,
+    )
+
+    trained_sae = language_model_sae_runner(cfg)
+
+    assert trained_sae is not None
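
Usage note: below is a minimal sketch of how the new flag is intended to be driven from the runner config, not part of the applied diff. It assumes the remaining LanguageModelSAERunnerConfig fields (learning rate, batch sizes, wandb settings, etc.) have workable defaults; in practice they can be set as in the tests above, and the dataset path and hook settings here are illustrative.

import torch
from sae_training.lm_runner import LanguageModelSAERunnerConfig, language_model_sae_runner

# Point the runner at a raw-text dataset: each row carries a "text" field instead of
# pre-computed "tokens", so the buffer tokenizes it on the fly with model.to_tokens.
cfg = LanguageModelSAERunnerConfig(
    model_name = "gelu-2l",
    hook_point = "blocks.0.hook_mlp_out",
    hook_point_layer = 0,
    d_in = 512,
    dataset_path = "roneneldan/TinyStories",
    is_dataset_tokenized = False,  # routes get_batch_tokens through the "text" branch
    dtype = torch.float32,
)
trained_sae = language_model_sae_runner(cfg)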