
Commit: get_shit_done
jbloom-md committed Dec 10, 2023
1 parent 3843c39 commit ce73042
Showing 11 changed files with 233 additions and 145 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -89,7 +89,26 @@ model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader
)

```
## Tutorials

See the `tutorials` folder for the following tutorials:
- `exercises_and_solutions.ipynb`: A copy of Callum McDougall's SAE exercises which provide background knowledge on this codebase.
- `evaluating_your_sae.ipynb`: A quick-and-dirty notebook showing how to check L0 and prediction loss with your SAE, as well as how to generate interactive dashboards using Callum's reproduction of [Anthropic's interface](https://transformer-circuits.pub/2023/monosemantic-features#setup-interface). A minimal L0 sketch follows below.
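
For a flavour of the L0 check that notebook performs, here is a minimal sketch computing L0 (the mean number of features firing per token) from a batch of SAE feature activations. `feature_acts` and its shape are assumed stand-ins, not the notebook's actual variables.

```python
import torch

# feature_acts: SAE feature activations for a batch of token positions,
# shape [n_tokens, d_sae] (hypothetical stand-in data).
feature_acts = torch.relu(torch.randn(32, 49152) - 3)

# L0 = average number of non-zero features per token.
l0 = (feature_acts > 0).float().sum(dim=-1).mean()
print(f"Mean L0: {l0.item():.1f}")
```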

## Example Dashboard

WandB dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.

![screenshot](dashboard_screenshot.png)


## Example Output

Here's one feature we found in the residual stream of Layer 10 of GPT-2 Small:

![screenshot of the feature dashboard](readme_screenshot_predict_pronoun_feature.png)

Open `gpt2_resid_pre10_predict_pronoun_feature.html` in your browser to interact with the dashboard (WIP).

Note: this feature could probably split into more mono-semantic features in a larger SAE trained for longer (this one had only about 49152 features and was trained on 10M tokens from OpenWebText).


## Citations and References:
Binary file added dashboard_screenshot.png
Binary file added readme_screenshot_predict_pronoun_feature.png
31 changes: 20 additions & 11 deletions sae_training/activations_store.py
@@ -12,16 +12,20 @@ class ActivationsStore:
"""
def __init__(
self, cfg, model: HookedTransformer,
data_path="NeelNanda/c4-code-tokenized-2b",
is_dataset_tokenized=True,
):
self.cfg = cfg
self.model = model
self.data_path = data_path
self.is_dataset_tokenized = is_dataset_tokenized
self.dataset = load_dataset(data_path, split="train", streaming=True)
self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
self.iterable_dataset = iter(self.dataset)

# check if it's tokenized
if "tokens" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = True
print("Dataset is tokenized! Updating config.")
elif "text" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = False
print("Dataset is not tokenized! Updating config.")

# fill half a buffer, so we can mix it with a new buffer
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()
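
For context, the auto-detection above can be reproduced standalone: peek at the first example of a streaming dataset and branch on whether it carries pre-tokenized `tokens` or raw `text`. This is a minimal sketch, not the class's exact code; the dataset path is the one that was previously hard-coded as a default.

```python
from datasets import load_dataset

dataset = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", streaming=True)
first_example = next(iter(dataset))

if "tokens" in first_example:
    print("Dataset is pre-tokenized.")
elif "text" in first_example:
    print("Dataset is raw text and will need tokenizing.")
```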
@@ -42,9 +46,13 @@ def get_batch_tokens(self):

# pbar = tqdm(total=batch_size, desc="Filling batches")
while batch_tokens.shape[0] < batch_size:
if not self.is_dataset_tokenized:
if not self.cfg.is_dataset_tokenized:
s = next(self.iterable_dataset)["text"]
tokens = self.model.to_tokens(s, truncate=False, move_to_device=True).squeeze(0)
tokens = self.model.to_tokens(
s,
truncate=True,
move_to_device=True,
).squeeze(0)
assert len(tokens.shape) == 1, f"tokens.shape should be 1D but was {tokens.shape}"
else:
tokens = torch.tensor(
@@ -54,7 +62,8 @@ )
)
token_len = tokens.shape[0]

while token_len > 0:
# TODO: Fix this so that we are limiting how many tokens we get from the same context.
while token_len > 0 and batch_tokens.shape[0] < batch_size:
# Space left in the current batch
space_left = context_size - current_length

@@ -127,15 +136,15 @@ def get_buffer(self, n_batches_in_buffer):
)

# Insert activations directly into pre-allocated buffer
pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
# pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens).to(self.cfg.device)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations

pbar.update(1)
# pbar.update(1)

new_buffer = new_buffer.reshape(-1, d_in)
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]
@@ -155,7 +164,7 @@ def get_data_loader(self,) -> DataLoader:

# 1. # create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer //2),
[self.get_buffer(self.cfg.n_batches_in_buffer // 2),
self.storage_buffer]
)

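A minimal sketch of the half-and-half buffer mixing above, with toy shapes. The shuffle and the idea that one half is retained as the next storage buffer are assumptions based on the surrounding code, not shown verbatim in this hunk.

```python
import torch

d_in, half_buffer = 768, 4096  # toy sizes, not the repo's defaults

storage_buffer = torch.randn(half_buffer, d_in)  # kept from the last refill
fresh_buffer = torch.randn(half_buffer, d_in)    # newly computed activations

# Mix old and new activations and shuffle so training batches are not
# correlated with the contexts they came from.
mixing_buffer = torch.cat([fresh_buffer, storage_buffer])
mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

# Keep one half back as the storage buffer for the next refill (assumed),
# and serve the other half to the training dataloader.
storage_buffer, train_half = mixing_buffer.chunk(2)
print(storage_buffer.shape, train_half.shape)
```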
18 changes: 17 additions & 1 deletion sae_training/config.py
@@ -26,6 +26,8 @@ class LanguageModelSAERunnerConfig:
# Training Parameters
l1_coefficient: float = 1e-3
lr: float = 3e-4
lr_scheduler_name: str = "constant" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
lr_warm_up_steps: int = 500
train_batch_size: int = 4096
context_size: int = 128

@@ -64,4 +66,18 @@ def __post_init__(self):
unique_id = wandb.util.generate_id()
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

assert self.dead_feature_window < self.feature_sampling_window, "dead_feature_window must be < feature_sampling_window"


# Print out some useful info:
n_tokens_per_buffer = self.store_batch_size * self.context_size * self.n_batches_in_buffer
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer
print(f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}")

total_training_steps = self.total_training_tokens // self.train_batch_size
print(f"Total training steps: {total_training_steps}")

# how many times will we sample dead neurons?
n_dead_feature_samples = total_training_steps // self.dead_feature_window - 1
print(f"n_dead_feature_samples: {n_dead_feature_samples}")
28 changes: 0 additions & 28 deletions sae_training/gather_activations.py

This file was deleted.

51 changes: 0 additions & 51 deletions sae_training/lm_datasets.py

This file was deleted.

91 changes: 91 additions & 0 deletions sae_training/optim.py
@@ -0,0 +1,91 @@
'''
Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
'''
import math
from typing import Optional

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler


# None
# Linear Warmup and decay
# Cosine Annealing with Warmup
# Cosine Annealing with Warmup / Restarts
def get_scheduler(
scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs
):
"""
Loosely based on this; it seemed simpler to write this than to import
transformers: https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
Args:
scheduler_name (Optional[str]): Name of the scheduler to use. If None, returns a constant scheduler
optimizer (optim.Optimizer): Optimizer to use
**kwargs: Additional arguments to pass to the scheduler including warm_up_steps,
training_steps, num_cycles, lr_end.
"""

def get_warmup_lambda(warm_up_steps, training_steps):
def lr_lambda(steps):
if steps < warm_up_steps:
return (steps + 1) / warm_up_steps
else:
return (training_steps - steps) / (
training_steps - warm_up_steps
)

return lr_lambda

# heavily derived from hugging face although copilot helped.
def get_warmup_cosine_lambda(warm_up_steps, training_steps, lr_end):
def lr_lambda(steps):
if steps < warm_up_steps:
return (steps + 1) / warm_up_steps
else:
progress = (steps - warm_up_steps) / (
training_steps - warm_up_steps
)
return lr_end + 0.5 * (1 - lr_end) * (
1 + math.cos(math.pi * progress)
)

return lr_lambda

if scheduler_name is None or scheduler_name.lower() == "constant":
return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0)
elif scheduler_name.lower() == "constantwithwarmup":
warm_up_steps = kwargs.get("warm_up_steps", 0)
return lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lambda steps: min(1.0, (steps + 1) / warm_up_steps),
)
elif scheduler_name.lower() == "linearwarmupdecay":
warm_up_steps = kwargs.get("warm_up_steps", 0)
training_steps = kwargs.get("training_steps")
lr_lambda = get_warmup_lambda(warm_up_steps, training_steps)
return lr_scheduler.LambdaLR(optimizer, lr_lambda)
elif scheduler_name.lower() == "cosineannealing":
training_steps = kwargs.get("training_steps")
eta_min = kwargs.get("lr_end", 0)
return lr_scheduler.CosineAnnealingLR(
optimizer, T_max=training_steps, eta_min=eta_min
)
elif scheduler_name.lower() == "cosineannealingwarmup":
warm_up_steps = kwargs.get("warm_up_steps", 0)
training_steps = kwargs.get("training_steps")
eta_min = kwargs.get("lr_end", 0)
lr_lambda = get_warmup_cosine_lambda(
warm_up_steps, training_steps, eta_min
)
return lr_scheduler.LambdaLR(optimizer, lr_lambda)
elif scheduler_name.lower() == "cosineannealingwarmrestarts":
training_steps = kwargs.get("training_steps")
eta_min = kwargs.get("lr_end", 0)
num_cycles = kwargs.get("num_cycles", 1)
T_0 = training_steps // num_cycles
return lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=T_0, eta_min=eta_min
)
else:
raise ValueError(f"Unsupported scheduler: {scheduler_name}")
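
A short usage sketch for `get_scheduler`; the parameter group and hyperparameters are placeholders and the training loop is schematic.

```python
import torch
import torch.optim as optim

from sae_training.optim import get_scheduler

params = [torch.nn.Parameter(torch.randn(10, 10))]  # stand-in for SAE parameters
optimizer = optim.Adam(params, lr=3e-4)

scheduler = get_scheduler(
    "constantwithwarmup",
    optimizer,
    warm_up_steps=500,
)

for step in range(1_000):
    optimizer.zero_grad()
    loss = (params[0] ** 2).mean()  # placeholder loss for the sketch
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the LR schedule once per training step
```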
