Added support for non-tokenized datasets
lucyfarnik committed Dec 2, 2023
1 parent d06e09b commit afcc239
Showing 3 changed files with 64 additions and 12 deletions.
25 changes: 14 additions & 11 deletions sae_training/activations_buffer.py
@@ -12,11 +12,14 @@
 
 class DataLoaderBuffer:
     def __init__(
-        self, cfg, model: HookedTransformer, data_path="NeelNanda/c4-code-tokenized-2b"
+        self, cfg, model: HookedTransformer,
+        data_path="NeelNanda/c4-code-tokenized-2b",
+        is_dataset_tokenized=True,
     ):
         self.cfg = cfg
         self.model = model
         self.data_path = data_path
+        self.is_dataset_tokenized = is_dataset_tokenized
         self.dataset = load_dataset(data_path, split="train", streaming=True)
         self.iterable_dataset = iter(self.dataset)
         self.buffer = torch.zeros(0, self.cfg.d_in, device=self.cfg.device)
@@ -37,16 +40,16 @@ def get_batch_tokens(self):
 
         # pbar = tqdm(total=batch_size, desc="Filling batches")
         while batch_tokens.shape[0] < batch_size:
-            # if not pretokenized:
-            # s = next(dataset)["text"]
-            # tokens = model.to_tokens(s, truncate=False, move_to_device=True).squeeze(0)
-            # assert len(tokens.shape) == 1, f"tokens.shape should be 1D but was {tokens.shape}"
-            # else:
-            tokens = torch.tensor(
-                next(self.iterable_dataset)["tokens"],
-                dtype=torch.long,
-                device=device,
-            )
+            if not self.is_dataset_tokenized:
+                s = next(self.iterable_dataset)["text"]
+                tokens = self.model.to_tokens(s, truncate=False, move_to_device=True).squeeze(0)
+                assert len(tokens.shape) == 1, f"tokens.shape should be 1D but was {tokens.shape}"
+            else:
+                tokens = torch.tensor(
+                    next(self.iterable_dataset)["tokens"],
+                    dtype=torch.long,
+                    device=device,
+                )
             token_len = tokens.shape[0]
 
             while token_len > 0:
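With this change the buffer can stream a raw-text dataset and tokenize each row on the fly with model.to_tokens, instead of requiring a pre-tokenized "tokens" column. Below is a minimal usage sketch of the non-tokenized path, not the canonical API: it assumes the runner config's defaults cover the remaining fields the buffer reads, and uses "roneneldan/TinyStories" only as an example of a dataset with plain "text" rows.

import torch
from transformer_lens import HookedTransformer

from sae_training.activations_buffer import DataLoaderBuffer
from sae_training.lm_runner import LanguageModelSAERunnerConfig

model = HookedTransformer.from_pretrained("gelu-2l")
cfg = LanguageModelSAERunnerConfig(
    model_name = "gelu-2l",
    d_in = 512,
    context_size = 128,
    device = "cpu",
    dtype = torch.float32,
)

# Rows of this dataset only carry a "text" field, so the buffer calls
# model.to_tokens on each one instead of reading a "tokens" column
# (the pre-tokenized default).
buffer = DataLoaderBuffer(
    cfg,
    model,
    data_path = "roneneldan/TinyStories",
    is_dataset_tokenized = False,
)
batch_tokens = buffer.get_batch_tokens()  # long tensor of token ids for one batch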
3 changes: 2 additions & 1 deletion sae_training/lm_runner.py
@@ -23,6 +23,7 @@ class LanguageModelSAERunnerConfig:
     hook_point: str = "blocks.0.hook_mlp_out"
     hook_point_layer: int = 0
     dataset_path: str = "NeelNanda/c4-tokenized-2b"
+    is_dataset_tokenized: bool = True
 
     # SAE Parameters
     d_in: int = 512
@@ -69,7 +70,7 @@ def language_model_sae_runner(cfg):
 
     # initialize dataset
     activations_buffer = DataLoaderBuffer(
-        cfg, model, data_path=cfg.dataset_path
+        cfg, model, data_path=cfg.dataset_path, is_dataset_tokenized=cfg.is_dataset_tokenized,
    )
 
     # initialize the SAE
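In the runner itself the new flag simply travels from the config into the buffer, so switching to a raw-text corpus only touches the config. A hedged end-to-end sketch follows, assuming the remaining config fields keep their defaults and wandb logging is off; the full, real configuration is the new benchmark test added below.

from sae_training.lm_runner import LanguageModelSAERunnerConfig, language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(
    model_name = "gelu-2l",
    hook_point = "blocks.0.hook_mlp_out",
    hook_point_layer = 0,
    d_in = 512,
    dataset_path = "roneneldan/TinyStories",
    is_dataset_tokenized = False,  # raw "text" rows, tokenized inside the buffer
    log_to_wandb = False,
)
trained_sae = language_model_sae_runner(cfg)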
48 changes: 48 additions & 0 deletions tests/benchmark/test_language_model_sae_runner.py
@@ -17,6 +17,7 @@ def test_language_model_sae_runner_mlp_out():
         hook_point_layer = 0,
         d_in = 512,
         dataset_path = "NeelNanda/c4-tokenized-2b",
+        is_dataset_tokenized=True,
 
         # SAE Parameters
         expansion_factor = 64, # determines the dimension of the SAE.
@@ -66,6 +67,7 @@ def test_language_model_sae_runner_resid_pre():
         hook_point_layer = 0,
         d_in = 512,
         dataset_path = "NeelNanda/c4-tokenized-2b",
+        is_dataset_tokenized=True,
 
         # SAE Parameters
         expansion_factor = 32, # determines the dimension of the SAE.
@@ -98,3 +100,49 @@ def test_language_model_sae_runner_resid_pre():
     assert trained_sae is not None
 
 
+def test_language_model_sae_runner_not_tokenized():
+
+    cfg = LanguageModelSAERunnerConfig(
+
+        # Data Generating Function (Model + Training Distribution)
+        model_name = "gelu-2l",
+        hook_point = "blocks.1.hook_mlp_out",
+        hook_point_layer = 0,
+        d_in = 512,
+        dataset_path = "roneneldan/TinyStories",
+        is_dataset_tokenized=False,
+
+        # SAE Parameters
+        expansion_factor = 64, # determines the dimension of the SAE.
+
+        # Training Parameters
+        lr = 1e-4,
+        l1_coefficient = 1e-4,
+        train_batch_size = 4096,
+        context_size = 128,
+
+        # Activation Store Parameters
+        n_batches_in_buffer = 8,
+        total_training_tokens = 2_500_000 * 60,
+        store_batch_size = 32,
+
+        # Resampling protocol
+        feature_sampling_window = 1000,
+        feature_reinit_scale = 0.2,
+        dead_feature_threshold = 1e-8,
+
+        # WANDB
+        log_to_wandb = True,
+        wandb_project = "mats_sae_training_language_models",
+        wandb_entity = None,
+
+        # Misc
+        device = "mps",
+        seed = 42,
+        checkpoint_path = "checkpoints",
+        dtype = torch.float32,
+    )
+
+    trained_sae = language_model_sae_runner(cfg)
+
+    assert trained_sae is not None
