
Commit

Merge pull request #58 from canrager/main
Make prepend BOS optional: Default True
jbloomAus authored Mar 28, 2024
2 parents cfafbe7 + 618d4bb commit 48a07f9
Showing 5 changed files with 68 additions and 50 deletions.
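In short, the PR adds a prepend_bos flag (and a verbose flag) to LanguageModelSAERunnerConfig and threads prepend_bos through the activations store, so BOS insertion can be disabled when building training batches. A minimal usage sketch; the arguments besides prepend_bos and verbose mirror values from the test fixtures later in this diff and are illustrative, not necessarily the full required argument set:

from sae_training.config import LanguageModelSAERunnerConfig

# Illustrative config: only prepend_bos and verbose are the options added by this PR.
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_point="blocks.1.hook_resid_pre",
    hook_point_layer=1,
    d_in=768,
    prepend_bos=True,   # new: prepend the BOS token to each context (default True)
    verbose=True,       # new: gates the informational prints in __post_init__
)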
34 changes: 21 additions & 13 deletions sae_training/activations_store.py
@@ -105,11 +105,6 @@ def get_batch_tokens(self):

# TODO: Fix this so that we are limiting how many tokens we get from the same context.
assert self.model.tokenizer is not None # keep pyright happy
bos_token_id_tensor = torch.tensor(
[self.model.tokenizer.bos_token_id],
device=tokens.device,
dtype=torch.long,
)
while token_len > 0 and batch_tokens.shape[0] < batch_size:
# Space left in the current batch
space_left = context_size - current_length
@@ -129,15 +124,21 @@ def get_batch_tokens(self):
token_len -= space_left

# only add BOS if it's not already the first token
if tokens[0] != bos_token_id_tensor:
tokens = torch.cat(
(
bos_token_id_tensor,
tokens,
),
dim=0,
if self.cfg.prepend_bos:
bos_token_id_tensor = torch.tensor(
[self.model.tokenizer.bos_token_id],
device=tokens.device,
dtype=torch.long,
)
token_len += 1
if tokens[0] != bos_token_id_tensor:
tokens = torch.cat(
(
bos_token_id_tensor,
tokens,
),
dim=0,
)
token_len += 1
current_length = context_size

# If a batch is full, concatenate and move to next batch
@@ -170,6 +171,7 @@ def get_activations(self, batch_tokens: torch.Tensor):
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
prepend_bos=self.cfg.prepend_bos,
)[1]
activations_list = [layerwise_activations[act_name] for act_name in act_names]
if self.cfg.hook_point_head_index is not None:
@@ -330,6 +332,7 @@ def _get_next_dataset_tokens(self) -> torch.Tensor:
s,
truncate=True,
move_to_device=True,
prepend_bos=self.cfg.prepend_bos,
).squeeze(0)
assert (
len(tokens.shape) == 1
@@ -341,4 +344,9 @@ def _get_next_dataset_tokens(self) -> torch.Tensor:
device=device,
requires_grad=False,
)
if (
not self.cfg.prepend_bos
and tokens[0] == self.model.tokenizer.bos_token_id # type: ignore
):
tokens = tokens[1:]
return tokens
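The net behavior of the two hunks above is: with prepend_bos=True the store prepends the tokenizer's BOS id to a context unless it is already the first token, and with prepend_bos=False _get_next_dataset_tokens strips a leading BOS that model.to_tokens may have added. A self-contained toy illustration of both paths (the BOS id 50256 is GPT-2's and the token ids are arbitrary; the real code reads the id from self.model.tokenizer.bos_token_id):

import torch

bos_token_id = 50256  # GPT-2's BOS/EOS id, used here purely for illustration
tokens = torch.tensor([15496, 995], dtype=torch.long)  # arbitrary token ids

# prepend_bos=True path: add BOS only if it is not already the first token
if tokens[0] != bos_token_id:
    tokens = torch.cat((torch.tensor([bos_token_id]), tokens), dim=0)
assert tokens.tolist() == [50256, 15496, 995]

# prepend_bos=False path: drop a leading BOS the tokenizer may have added
if tokens[0] == bos_token_id:
    tokens = tokens[1:]
assert tokens.tolist() == [15496, 995]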
70 changes: 37 additions & 33 deletions sae_training/config.py
@@ -86,6 +86,8 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
# Misc
n_checkpoints: int = 0
checkpoint_path: str = "checkpoints"
prepend_bos: bool = True
verbose: bool = True

def __post_init__(self):
super().__post_init__()
@@ -114,46 +116,48 @@ def __post_init__(self):
).util.generate_id() # not sure why this type is erroring
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

print(
f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
)
# Print out some useful info:
n_tokens_per_buffer = (
self.store_batch_size * self.context_size * self.n_batches_in_buffer
)
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer
print(
f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}"
)
if self.verbose:
print(
f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
)
# Print out some useful info:
n_tokens_per_buffer = (
self.store_batch_size * self.context_size * self.n_batches_in_buffer
)
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer
print(
f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}"
)

total_training_steps = self.total_training_tokens // self.train_batch_size
print(f"Total training steps: {total_training_steps}")
total_training_steps = self.total_training_tokens // self.train_batch_size
print(f"Total training steps: {total_training_steps}")

total_wandb_updates = total_training_steps // self.wandb_log_frequency
print(f"Total wandb updates: {total_wandb_updates}")
total_wandb_updates = total_training_steps // self.wandb_log_frequency
print(f"Total wandb updates: {total_wandb_updates}")

# how many times will we sample dead neurons?
# assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
n_feature_window_samples = total_training_steps // self.feature_sampling_window
print(
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}"
)
print(
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
)
# how many times will we sample dead neurons?
# assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
n_feature_window_samples = (
total_training_steps // self.feature_sampling_window
)
print(
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}"
)
print(
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
)
print(
f"We will reset the sparsity calculation {n_feature_window_samples} times."
)
# print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
print(
f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}"
)

if self.use_ghost_grads:
print("Using Ghost Grads.")

print(
f"We will reset the sparsity calculation {n_feature_window_samples} times."
)
# print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
print(
f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}"
)


@dataclass
class CacheActivationsRunnerConfig(RunnerConfig):
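For concreteness, here is the arithmetic behind the quantities these (now verbose-gated) prints report, using the store_batch_size, n_batches_in_buffer, and total_training_tokens values from the run.ipynb cells later in this diff; context_size and train_batch_size are assumed values, since they are not visible in the shown hunks:

# Assumed/illustrative values; only store_batch_size, n_batches_in_buffer and
# total_training_tokens appear in the notebook cells in this diff.
store_batch_size = 32
n_batches_in_buffer = 128
total_training_tokens = 1_000_000 * 20
context_size = 128        # assumption
train_batch_size = 4096   # assumption

n_tokens_per_buffer = store_batch_size * context_size * n_batches_in_buffer
# 32 * 128 * 128 = 524_288 tokens (~0.52 million) per buffer

n_contexts_per_buffer = store_batch_size * n_batches_in_buffer
# 32 * 128 = 4_096 contexts per buffer

total_training_steps = total_training_tokens // train_batch_size
# 20_000_000 // 4_096 = 4_882 steps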
6 changes: 3 additions & 3 deletions scripts/run.ipynb
@@ -79,7 +79,7 @@
" dataset_path=\"NeelNanda/c4-tokenized-2b\",\n",
" is_dataset_tokenized=True,\n",
" # SAE Parameters\n",
" expansion_factor=[16,32,64],\n",
" expansion_factor=[16, 32, 64],\n",
" b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n",
" # Training Parameters\n",
" lr=0.0012,\n",
@@ -368,7 +368,7 @@
" n_batches_in_buffer=128,\n",
" total_training_tokens=1_000_000 * 20,\n",
" store_batch_size=32,\n",
" feature_sampling_window=500, # So we see the histograms. \n",
" feature_sampling_window=500, # So we see the histograms.\n",
" dead_feature_window=250,\n",
" # WANDB\n",
" log_to_wandb=True,\n",
@@ -697,7 +697,7 @@
" n_batches_in_buffer=128,\n",
" total_training_tokens=1_000_000 * 20,\n",
" store_batch_size=32,\n",
" feature_sampling_window=500, # So we see the histograms. \n",
" feature_sampling_window=500, # So we see the histograms.\n",
" dead_feature_window=250,\n",
" # WANDB\n",
" log_to_wandb=True,\n",
1 change: 1 addition & 0 deletions tests/unit/helpers.py
@@ -42,6 +42,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
seed=24,
checkpoint_path="test/checkpoints",
dtype=torch.float32,
prepend_bos=True,
)

for key, val in kwargs.items():
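build_sae_cfg applies keyword overrides on top of these defaults (the loop at the bottom of the hunk), so unit tests can flip the new flag directly. A brief, illustrative usage:

# Illustrative: override the new default in a unit test.
cfg = build_sae_cfg(prepend_bos=False)
assert cfg.prepend_bos is False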
7 changes: 6 additions & 1 deletion tests/unit/test_activations_store.py
@@ -27,6 +27,7 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 64,
"prepend_bos": True,
},
{
"model_name": "tiny-stories-1M",
@@ -35,6 +36,7 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
"hook_point": "blocks.1.attn.hook_z",
"hook_point_layer": 1,
"d_in": 64,
"prepend_bos": True,
},
{
"model_name": "gelu-2l",
@@ -43,6 +45,7 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 512,
"prepend_bos": True,
},
{
"model_name": "gpt2",
@@ -51,6 +54,7 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 768,
"prepend_bos": True,
},
{
"model_name": "gpt2",
@@ -59,6 +63,7 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
"hook_point": "blocks.1.hook_resid_pre",
"hook_point_layer": 1,
"d_in": 768,
"prepend_bos": True,
},
],
ids=[
@@ -105,7 +110,7 @@ def cfg(request: pytest.FixtureRequest) -> SimpleNamespace:
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
mock_config.dtype = torch.float32

mock_config.prepend_bos = params["prepend_bos"]
return mock_config


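All five parametrizations above exercise prepend_bos=True; a hypothetical extra test for the opposite setting might look like the sketch below (the test name and fixtures are assumptions, and the asserted behavior follows the _get_next_dataset_tokens change earlier in this diff):

def test_next_dataset_tokens_without_bos(model, activation_store):
    # Hypothetical: with cfg.prepend_bos=False, the store strips a leading BOS
    # that model.to_tokens may have added, so the sequence should not start with it.
    tokens = activation_store._get_next_dataset_tokens()
    if not activation_store.cfg.prepend_bos:
        assert tokens[0] != model.tokenizer.bos_token_id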
