add prepend bos flag
canrager committed Mar 28, 2024
1 parent 13c8085 commit c0b29cc
Showing 3 changed files with 63 additions and 51 deletions.
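The new prepend_bos config flag is threaded through to TransformerLens's to_tokens and run_with_cache calls and to the activation store's own context-splitting logic. As a quick illustration of what the flag controls at the tokenizer level, here is a minimal sketch (the model choice is illustrative and not part of this commit):

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
print(model.to_tokens("Hello world", prepend_bos=True))   # leads with the BOS id (50256 for GPT-2)
print(model.to_tokens("Hello world", prepend_bos=False))  # raw token ids only, one token shorter
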
31 changes: 18 additions & 13 deletions sae_training/activations_store.py
@@ -105,11 +105,6 @@ def get_batch_tokens(self):
 
             # TODO: Fix this so that we are limiting how many tokens we get from the same context.
             assert self.model.tokenizer is not None  # keep pyright happy
-            bos_token_id_tensor = torch.tensor(
-                [self.model.tokenizer.bos_token_id],
-                device=tokens.device,
-                dtype=torch.long,
-            )
             while token_len > 0 and batch_tokens.shape[0] < batch_size:
                 # Space left in the current batch
                 space_left = context_size - current_length
@@ -129,15 +124,21 @@ def get_batch_tokens(self):
                     token_len -= space_left
 
                     # only add BOS if it's not already the first token
-                    if tokens[0] != bos_token_id_tensor:
-                        tokens = torch.cat(
-                            (
-                                bos_token_id_tensor,
-                                tokens,
-                            ),
-                            dim=0,
+                    if self.cfg.prepend_bos:
+                        bos_token_id_tensor = torch.tensor(
+                            [self.model.tokenizer.bos_token_id],
+                            device=tokens.device,
+                            dtype=torch.long,
                         )
-                    token_len += 1
+                        if tokens[0] != bos_token_id_tensor:
+                            tokens = torch.cat(
+                                (
+                                    bos_token_id_tensor,
+                                    tokens,
+                                ),
+                                dim=0,
+                            )
+                            token_len += 1
                     current_length = context_size
 
                 # If a batch is full, concatenate and move to next batch
@@ -170,6 +171,7 @@ def get_activations(self, batch_tokens: torch.Tensor):
             batch_tokens,
             names_filter=act_names,
             stop_at_layer=hook_point_max_layer + 1,
+            prepend_bos=self.cfg.prepend_bos,
         )[1]
         activations_list = [layerwise_activations[act_name] for act_name in act_names]
         if self.cfg.hook_point_head_index is not None:
@@ -330,6 +332,7 @@ def _get_next_dataset_tokens(self) -> torch.Tensor:
                 s,
                 truncate=True,
                 move_to_device=True,
+                prepend_bos=self.cfg.prepend_bos,
             ).squeeze(0)
             assert (
                 len(tokens.shape) == 1
@@ -341,4 +344,6 @@ def _get_next_dataset_tokens(self) -> torch.Tensor:
                 device=device,
                 requires_grad=False,
             )
+            if not self.cfg.prepend_bos and tokens[0] == self.model.tokenizer.bos_token_id:
+                tokens = tokens[1:]
         return tokens
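
Note on the change above: when a long sequence is split across contexts in get_batch_tokens, a BOS token is now re-prepended to the remainder only if cfg.prepend_bos is set, and pre-tokenized datasets have a leading BOS stripped when the flag is off. A standalone sketch of the splitting step (the helper name and signature are hypothetical; the real code reads the flag and tokenizer from self.cfg and self.model):

import torch


def split_off_context(tokens: torch.Tensor, space_left: int, bos_token_id: int,
                      prepend_bos: bool) -> torch.Tensor:
    # Drop the part that filled the current context; optionally re-prepend BOS
    # so the next context starts with one (guarding against a double BOS).
    remainder = tokens[space_left:]
    if prepend_bos and remainder[0] != bos_token_id:
        bos = torch.tensor([bos_token_id], device=remainder.device, dtype=torch.long)
        remainder = torch.cat((bos, remainder), dim=0)
    return remainder


seq = torch.arange(1, 11, dtype=torch.long)              # a 10-token sequence, BOS id taken to be 0
print(split_off_context(seq, 4, 0, prepend_bos=True))    # tensor([ 0,  5,  6,  7,  8,  9, 10])
print(split_off_context(seq, 4, 0, prepend_bos=False))   # tensor([ 5,  6,  7,  8,  9, 10])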
68 changes: 35 additions & 33 deletions sae_training/config.py
@@ -86,6 +86,8 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
     # Misc
     n_checkpoints: int = 0
     checkpoint_path: str = "checkpoints"
+    prepend_bos: bool = True
+    verbose: bool = True
 
     def __post_init__(self):
         super().__post_init__()
@@ -114,46 +116,46 @@ def __post_init__(self):
         ).util.generate_id()  # not sure why this type is erroring
         self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"
 
-        print(
-            f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
-        )
-        # Print out some useful info:
-        n_tokens_per_buffer = (
-            self.store_batch_size * self.context_size * self.n_batches_in_buffer
-        )
-        print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
-        n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer
-        print(
-            f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}"
-        )
+        if self.verbose:
+            print(
+                f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
+            )
+            # Print out some useful info:
+            n_tokens_per_buffer = (
+                self.store_batch_size * self.context_size * self.n_batches_in_buffer
+            )
+            print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
+            n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer
+            print(
+                f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}"
+            )
 
-        total_training_steps = self.total_training_tokens // self.train_batch_size
-        print(f"Total training steps: {total_training_steps}")
+            total_training_steps = self.total_training_tokens // self.train_batch_size
+            print(f"Total training steps: {total_training_steps}")
 
-        total_wandb_updates = total_training_steps // self.wandb_log_frequency
-        print(f"Total wandb updates: {total_wandb_updates}")
+            total_wandb_updates = total_training_steps // self.wandb_log_frequency
+            print(f"Total wandb updates: {total_wandb_updates}")
 
-        # how many times will we sample dead neurons?
-        # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
-        n_feature_window_samples = total_training_steps // self.feature_sampling_window
-        print(
-            f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}"
-        )
-        print(
-            f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
-        )
+            # how many times will we sample dead neurons?
+            # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
+            n_feature_window_samples = total_training_steps // self.feature_sampling_window
+            print(
+                f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}"
+            )
+            print(
+                f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
+            )
+            print(
+                f"We will reset the sparsity calculation {n_feature_window_samples} times."
+            )
+            # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
+            print(
+                f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}"
+            )
 
         if self.use_ghost_grads:
             print("Using Ghost Grads.")
 
-        print(
-            f"We will reset the sparsity calculation {n_feature_window_samples} times."
-        )
-        # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
-        print(
-            f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}"
-        )
-
 
 @dataclass
 class CacheActivationsRunnerConfig(RunnerConfig):
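
The two new config fields default to the old behaviour: prepend_bos=True keeps prepending BOS tokens, and verbose=True keeps printing the __post_init__ summary, so existing configs should be unaffected. A simplified sketch of the pattern (this toy dataclass stands in for LanguageModelSAERunnerConfig, which has many more fields):

from dataclasses import dataclass


@dataclass
class ToyRunnerConfig:
    context_size: int = 128
    store_batch_size: int = 32
    n_batches_in_buffer: int = 20
    prepend_bos: bool = True   # forwarded to tokenization / forward passes
    verbose: bool = True       # gates the __post_init__ summary printing

    def __post_init__(self):
        if self.verbose:
            n_tokens_per_buffer = (
                self.store_batch_size * self.context_size * self.n_batches_in_buffer
            )
            print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10**6}")


cfg = ToyRunnerConfig(verbose=False)   # constructs silently; cfg.prepend_bos is True by default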
15 changes: 10 additions & 5 deletions sae_training/evals.py
@@ -70,8 +70,7 @@ def run_evals(
     l2_norm_out = torch.norm(sae_out, dim=-1)
     l2_norm_ratio = l2_norm_out / l2_norm_in
 
-    wandb.log(
-        {
+    metrics = {
             # l2 norms
             f"metrics/l2_norm{suffix}": l2_norm_out.mean().item(),
             f"metrics/l2_ratio{suffix}": l2_norm_ratio.mean().item(),
@@ -80,9 +79,15 @@
             f"metrics/ce_loss_without_sae{suffix}": ntp_loss,
             f"metrics/ce_loss_with_sae{suffix}": recons_loss,
             f"metrics/ce_loss_with_ablation{suffix}": zero_abl_loss,
-        },
-        step=n_training_steps,
-    )
+    }
+
+    if wandb.run is not None:
+        wandb.log(
+            metrics,
+            step=n_training_steps
+        )
+
+    return metrics
 
     # head_index = sparse_autoencoder.cfg.hook_point_head_index
 
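
run_evals now assembles the metrics dict first, logs it only when a wandb run is active, and returns it, so evaluations can run without wandb being initialised. A minimal sketch of that pattern with placeholder metric values (the helper below is illustrative, not the real eval computation):

import wandb


def log_eval_metrics(l2_ratio: float, ce_loss_score: float, n_training_steps: int, suffix: str = ""):
    metrics = {
        f"metrics/l2_ratio{suffix}": l2_ratio,
        f"metrics/CE_loss_score{suffix}": ce_loss_score,
    }
    # wandb.run is None until wandb.init() has been called, so this is a no-op offline.
    if wandb.run is not None:
        wandb.log(metrics, step=n_training_steps)
    return metrics


print(log_eval_metrics(0.98, 0.91, n_training_steps=100))   # works with or without an active wandb run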
