+ )
# Also, delete the columns as appropriate if the number is between 0 and 5
else:
- html_output = html_output.replace('+0.000', "")
+ html_output = html_output.replace(
+ '+0.000',
+ "",
+ )
return html_output
-
-
-
def generate_seq_html(
vocab_dict: dict,
token_ids: List[str],
@@ -117,7 +136,9 @@ def generate_seq_html(
bold_idx: Optional[int] = None,
is_repeat: bool = False,
):
- assert len(token_ids) == len(feat_acts) == len(contribution_to_loss), "All input lists must be of the same length."
+ assert (
+ len(token_ids) == len(feat_acts) == len(contribution_to_loss)
+ ), "All input lists must be of the same length."
# ! Clip values in [0, 1] range (temporary)
bg_values = np.clip(feat_acts, 0, 1)
@@ -132,7 +153,6 @@ def generate_seq_html(
html_output = '<div class="seq">'  # + repeat_obj
for i in range(len(token_ids)):
-
# Get background color, which is {0: transparent, +1: darkorange}
bg_val = bg_values[i]
bg_color = colors.rgb2hex(BG_COLOR_MAP(bg_val))
@@ -149,30 +169,25 @@ def generate_seq_html(
underline_color = f"rgb(255, {v}, {v})"
html_output += generate_tok_html(
- vocab_dict = vocab_dict,
- this_token = token_ids[i],
- underline_color = underline_color,
- bg_color = bg_color,
- pos_ids = pos_ids[i],
- neg_ids = neg_ids[i],
- pos_val = pos_val[i],
- neg_val = neg_val[i],
- is_bold = (bold_idx is not None) and (bold_idx == i),
- feat_act = feat_acts[i],
- contribution_to_loss = contribution_to_loss[i],
+ vocab_dict=vocab_dict,
+ this_token=token_ids[i],
+ underline_color=underline_color,
+ bg_color=bg_color,
+ pos_ids=pos_ids[i],
+ neg_ids=neg_ids[i],
+ pos_val=pos_val[i],
+ neg_val=neg_val[i],
+ is_bold=(bold_idx is not None) and (bold_idx == i),
+ feat_act=feat_acts[i],
+ contribution_to_loss=contribution_to_loss[i],
)
- html_output += '</div>'
+ html_output += "</div>"
return html_output
-
-
-
-
def generate_tables_html(
-
# First, all the arguments for the left-hand tables
neuron_alignment_indices: List[int],
neuron_alignment_values: List[float],
@@ -183,7 +198,6 @@ def generate_tables_html(
correlated_features_indices: List[int],
correlated_features_pearson: List[float],
correlated_features_l1: List[float],
-
# Second, all the arguments for the middle tables (neg/pos logits)
neg_str: List[str],
neg_values: List[float],
@@ -192,12 +206,12 @@ def generate_tables_html(
pos_values: List[float],
pos_bg_values: List[float],
):
- '''
+ """
See the file `threerow_table_template.html` (with the CSS in the other 3 files), for this to make more sense.
- '''
+ """
html_output = HTML_LEFT_TABLES
- for (letter, mylist, myformat) in zip(
+ for letter, mylist, myformat in zip(
"IVLIPCIPC",
[
neuron_alignment_indices,
@@ -214,18 +228,30 @@ def generate_tables_html(
correlated_neurons_pearson,
correlated_neurons_l1,
],
- [None, "+.2f", ".1%", None, "+.2f", "+.2f", None, "+.2f", "+.2f"]
+ [None, "+.2f", ".1%", None, "+.2f", "+.2f", None, "+.2f", "+.2f"],
):
- fn = lambda m: str(mylist[int(m.group(1))]) if myformat is None else format(mylist[int(m.group(1))], myformat)
+ fn = (
+ lambda m: str(mylist[int(m.group(1))])
+ if myformat is None
+ else format(mylist[int(m.group(1))], myformat)
+ )
html_output = re.sub(letter + "(\d)", fn, html_output, count=3)
-
html_output_2 = HTML_LOGIT_TABLES
- neg_bg_colors = [f"rgba(255, {int(255 * (1 - v))}, {int(255 * (1 - v))}, 0.5)" for v in neg_bg_values]
- pos_bg_colors = [f"rgba({int(255 * (1 - v))}, {int(255 * (1 - v))}, 255, 0.5)" for v in pos_bg_values]
-
- for (letter, mylist) in zip("SVCSVC", [neg_str, neg_values, neg_bg_colors, pos_str, pos_values, pos_bg_colors]):
+ neg_bg_colors = [
+ f"rgba(255, {int(255 * (1 - v))}, {int(255 * (1 - v))}, 0.5)"
+ for v in neg_bg_values
+ ]
+ pos_bg_colors = [
+ f"rgba({int(255 * (1 - v))}, {int(255 * (1 - v))}, 255, 0.5)"
+ for v in pos_bg_values
+ ]
+
+ for letter, mylist in zip(
+ "SVCSVC",
+ [neg_str, neg_values, neg_bg_colors, pos_str, pos_values, pos_bg_colors],
+ ):
if letter == "S":
fn = lambda m: str(mylist[int(m.group(1))]).replace(" ", " ")
elif letter == "V":
@@ -235,15 +261,17 @@ def generate_tables_html(
html_output_2 = re.sub(letter + "(\d)", fn, html_output_2, count=10)
return (html_output, html_output_2)
-
def generate_histograms(freq_hist_data, logits_hist_data) -> Tuple[str, str]:
- '''This generates both histograms at once.'''
+ """This generates both histograms at once."""
# Start off high, cause we want closer to orange than white for the left-most bars
freq_bar_values = freq_hist_data.bar_values
- freq_bar_values_clipped = [(0.4 * max(freq_bar_values) + 0.6 * v) / max(freq_bar_values) for v in freq_bar_values]
+ freq_bar_values_clipped = [
+ (0.4 * max(freq_bar_values) + 0.6 * v) / max(freq_bar_values)
+ for v in freq_bar_values
+ ]
freq_bar_colors = [colors.rgb2hex(BG_COLOR_MAP(v)) for v in freq_bar_values_clipped]
return (
@@ -258,7 +286,6 @@ def generate_histograms(freq_hist_data, logits_hist_data) -> Tuple[str, str]:
# Fill in all the freq logits histogram data
.replace("BAR_HEIGHTS_LOGITS", str(list(logits_hist_data.bar_heights)))
.replace("BAR_VALUES_LOGITS", str(list(logits_hist_data.bar_values)))
- .replace("TICK_VALS_LOGITS", str(list(logits_hist_data.tick_vals)))
+ .replace("TICK_VALS_LOGITS", str(list(logits_hist_data.tick_vals)))
),
)
-
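
A note on the templating above: the `re.sub(letter + "(\d)", fn, ...)` calls fill numbered placeholders in the HTML templates, where a letter names a column and the digit indexes into the matching list. A minimal sketch of the idea, using a hypothetical template string (not the repo's actual template):

    import re

    # Hypothetical template: "I<n>" marks index slots, "V<n>" value slots.
    template = "<td>I0</td><td>V0</td><td>I1</td><td>V1</td>"
    indices = [7, 123]
    values = [0.52, -0.31]

    for letter, mylist, myformat in [("I", indices, None), ("V", values, "+.2f")]:
        # m.group(1) is the digit after the letter; it indexes into mylist
        fn = lambda m, mylist=mylist, myformat=myformat: (
            str(mylist[int(m.group(1))])
            if myformat is None
            else format(mylist[int(m.group(1))], myformat)
        )
        template = re.sub(letter + r"(\d)", fn, template)

    print(template)  # <td>7</td><td>+0.52</td><td>123</td><td>-0.31</td>
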
diff --git a/sae_analysis/visualizer/model_fns.py b/sae_analysis/visualizer/model_fns.py
index 32bfba5c..068d94d8 100644
--- a/sae_analysis/visualizer/model_fns.py
+++ b/sae_analysis/visualizer/model_fns.py
@@ -9,9 +9,11 @@
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
+
@dataclass
class AutoEncoderConfig:
- '''Class for storing configuration parameters for the autoencoder'''
+ """Class for storing configuration parameters for the autoencoder"""
+
seed: int = 42
batch_size: int = 32
buffer_mult: int = 384
@@ -29,15 +31,14 @@ class AutoEncoderConfig:
model_batch_size: int = 64
def __post_init__(self):
- '''Using kwargs, so that we can pass in a dict of parameters which might be
- a superset of the above, without error.'''
+ """Using kwargs, so that we can pass in a dict of parameters which might be
+ a superset of the above, without error."""
self.buffer_size = self.batch_size * self.buffer_mult
self.buffer_batches = self.buffer_size // self.seq_len
self.dtype = DTYPES[self.enc_dtype]
self.d_hidden = self.d_mlp * self.dict_mult
-
class AutoEncoder(nn.Module):
def __init__(self, cfg: AutoEncoderConfig):
super().__init__()
@@ -46,8 +47,16 @@ def __init__(self, cfg: AutoEncoderConfig):
torch.manual_seed(cfg.seed)
# W_enc has shape (d_mlp, d_encoder), where d_encoder is a multiple of d_mlp (cause dictionary learning; overcomplete basis)
- self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_mlp, cfg.d_hidden, dtype=cfg.dtype)))
- self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg.d_hidden, cfg.d_mlp, dtype=cfg.dtype)))
+ self.W_enc = nn.Parameter(
+ torch.nn.init.kaiming_uniform_(
+ torch.empty(cfg.d_mlp, cfg.d_hidden, dtype=cfg.dtype)
+ )
+ )
+ self.W_dec = nn.Parameter(
+ torch.nn.init.kaiming_uniform_(
+ torch.empty(cfg.d_hidden, cfg.d_mlp, dtype=cfg.dtype)
+ )
+ )
self.b_enc = nn.Parameter(torch.zeros(cfg.d_hidden, dtype=cfg.dtype))
self.b_dec = nn.Parameter(torch.zeros(cfg.d_mlp, dtype=cfg.dtype))
self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
@@ -69,13 +78,15 @@ def forward(self, x: torch.Tensor):
@torch.no_grad()
def remove_parallel_component_of_grads(self):
W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
- W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
+ W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
+ -1, keepdim=True
+ ) * W_dec_normed
self.W_dec.grad -= W_dec_grad_proj
@classmethod
def load_from_hf(cls, version, verbose=False):
"""
- Loads the saved autoencoder from HuggingFace.
+ Loads the saved autoencoder from HuggingFace.
Note, this is a classmethod, because we'll be using it as `auto_encoder = AutoEncoder.load_from_hf("run1")`
@@ -86,18 +97,25 @@ def load_from_hf(cls, version, verbose=False):
"""
assert version in ["run1", "run2"]
- version = 25 if version=="run1" else 47
+ version = 25 if version == "run1" else 47
- cfg: dict = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
+ cfg: dict = utils.download_file_from_hf(
+ "NeelNanda/sparse_autoencoder", f"{version}_cfg.json"
+ )
# There are some unnecessary params in cfg cause they're defined in post_init for config dataclass; we remove them
cfg.pop("buffer_batches", None)
cfg.pop("buffer_size", None)
- if verbose: pprint.pprint(cfg)
+ if verbose:
+ pprint.pprint(cfg)
cfg = AutoEncoderConfig(**cfg)
self = cls(cfg=cfg)
- self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
+ self.load_state_dict(
+ utils.download_file_from_hf(
+ "NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True
+ )
+ )
return self
def __repr__(self):
- return f"AutoEncoder(d_mlp={self.cfg.d_mlp}, dict_mult={self.cfg.dict_mult})"
\ No newline at end of file
+ return f"AutoEncoder(d_mlp={self.cfg.d_mlp}, dict_mult={self.cfg.dict_mult})"
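
`remove_parallel_component_of_grads` above keeps the decoder rows unit-norm under gradient descent by projecting each gradient row onto the subspace orthogonal to its decoder row. A minimal standalone sketch of that projection (not the repo's code):

    import torch

    W_dec = torch.randn(8, 4)
    W_dec = W_dec / W_dec.norm(dim=-1, keepdim=True)  # unit-norm rows
    grad = torch.randn(8, 4)                          # stand-in for W_dec.grad

    W_normed = W_dec / W_dec.norm(dim=-1, keepdim=True)
    parallel = (grad * W_normed).sum(-1, keepdim=True) * W_normed
    grad_perp = grad - parallel

    # Each projected row is now orthogonal to its decoder row (dot products ~ 0).
    print((grad_perp * W_normed).sum(-1).abs().max())
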
diff --git a/sae_analysis/visualizer/utils_fns.py b/sae_analysis/visualizer/utils_fns.py
index c940dd23..51ab8980 100644
--- a/sae_analysis/visualizer/utils_fns.py
+++ b/sae_analysis/visualizer/utils_fns.py
@@ -1,29 +1,31 @@
-from jaxtyping import Float, Int
-from typing import Tuple, Optional, List, Union, Dict
import re
-import torch
-from torch import Tensor
-from transformer_lens import HookedTransformer
+from typing import Dict, List, Optional, Tuple, Union
+
import einops
import numpy as np
+import torch
from eindex import eindex
+from jaxtyping import Float, Int
+from torch import Tensor
+from transformer_lens import HookedTransformer
Arr = np.ndarray
+
def k_largest_indices(
- x: Float[Tensor, "rows cols"],
+ x: Float[Tensor, "rows cols"], # noqa
k: int,
largest: bool = True,
buffer: Tuple[int, int] = (5, 5),
-) -> Int[Tensor, "k 2"]:
- '''w
+) -> Int[Tensor, "k 2"]: # noqa
+ """
Given a 2D array, returns the indices of the top or bottom `k` elements.
Also has a `buffer` argument, which makes sure we don't pick too close to the left/right of sequence. If `buffer`
is (5, 5), that means we shouldn't be allowed to pick the first or last 5 sequence positions, because we'll need
to append them to the left/right of the sequence. We should only be allowed from [5:-5] in this case.
- '''
- x = x[:, buffer[0]:-buffer[1]]
+ """
+ x = x[:, buffer[0] : -buffer[1]]
indices = x.flatten().topk(k=k, largest=largest).indices
rows = indices // x.size(1)
cols = indices % x.size(1) + buffer[0]
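
The arithmetic above recovers 2D coordinates from the flattened `topk` indices, then shifts columns back by the left buffer. A small worked example under the same convention:

    import torch

    x = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    buffer = (2, 2)
    x_inner = x[:, buffer[0] : -buffer[1]]        # drop 2 positions on each side

    indices = x_inner.flatten().topk(k=3).indices
    rows = indices // x_inner.size(1)
    cols = indices % x_inner.size(1) + buffer[0]  # shift back to original columns
    print(torch.stack((rows, cols), dim=1))       # tensor([[3, 3], [3, 2], [2, 3]])
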
@@ -31,26 +33,26 @@ def k_largest_indices(
def sample_unique_indices(large_number, small_number):
- '''Samples a small number of unique indices from a large number of indices.'''
+ """Samples a small number of unique indices from a large number of indices."""
weights = torch.ones(large_number) # Equal weights for all indices
sampled_indices = torch.multinomial(weights, small_number, replacement=False)
return sampled_indices
def random_range_indices(
- x: Float[Tensor, "batch seq"],
+ x: Float[Tensor, "batch seq"], # noqa
bounds: Tuple[float, float],
k: int,
buffer: Tuple[int, int] = (5, 5),
-) -> Int[Tensor, "k 2"]:
- '''
+) -> Int[Tensor, "k 2"]: # noqa
+ """
Given a 2D array, returns the indices of `k` elements whose values are in the range `bounds`.
Will return fewer than `k` values if there aren't enough values in the range.
Also has a `buffer` argument, which makes sure we don't pick too close to the left/right of sequence.
- '''
+ """
# Limit x, because our indices (bolded words) shouldn't be too close to the left/right of sequence
- x = x[:, buffer[0]:-buffer[1]]
+ x = x[:, buffer[0] : -buffer[1]]
# Create a mask for where x is in range, and get the indices as a tensor of shape (k, 2)
mask = (bounds[0] <= x) & (x <= bounds[1])
@@ -64,7 +66,6 @@ def random_range_indices(
return indices + torch.tensor([0, buffer[0]]).to(indices.device)
-
# # Example, where it'll pick the elements from the end of this 2D tensor, working backwards
# x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# k = 3
@@ -77,14 +78,13 @@ def random_range_indices(
# print(random_range_indices(x, bounds, k))
-
def to_str_tokens(vocab_dict: Dict[int, str], tokens: Union[int, torch.Tensor]):
- '''
+ """
If tokens is 1D, does the same thing as model.to_str_tokens.
If tokens is 2D or 3D, it flattens, does this thing, then reshapes.
Also, makes sure that line breaks are replaced with their repr.
- '''
+ """
if isinstance(tokens, int):
return vocab_dict[tokens]
@@ -101,8 +101,6 @@ def to_str_tokens(vocab_dict: Dict[int, str], tokens: Union[int, torch.Tensor]):
return reshape(str_tokens, tokens.shape)
-
-
def reshape(my_list, shape):
assert np.prod(shape) == len(my_list), "Shape is not compatible with list size"
assert len(shape) in [1, 2, 3], "Only shapes of length 1, 2, or 3 are supported"
@@ -114,39 +112,47 @@ def reshape(my_list, shape):
if len(shape) == 2:
return [[next(it) for _ in range(shape[1])] for _ in range(shape[0])]
- return [[[next(it) for _ in range(shape[2])] for _ in range(shape[1])] for _ in range(shape[0])]
-
+ return [
+ [[next(it) for _ in range(shape[2])] for _ in range(shape[1])]
+ for _ in range(shape[0])
+ ]
class TopK:
- '''
+ """
Wrapper around the object returned by torch.topk, which has the following 3 advantages:
-
+
> friendlier to type annotation
> easy device moving, without having to do it separately for values & indices
> easy indexing, without having to do it separately for values & indices
> other classic tensor operations, like .ndim, .shape, etc. work as expected
-
+
We initialise with a topk object, which is treated as a tuple of (values, indices).
- '''
+ """
def __init__(self, obj: Optional[Tuple[Arr, Arr]] = None):
- self.values: Arr = obj[0] if isinstance(obj[0], Arr) else obj[0].detach().cpu().numpy()
- self.indices: Arr = obj[1] if isinstance(obj[1], Arr) else obj[1].detach().cpu().numpy()
-
+ self.values: Arr = (
+ obj[0] if isinstance(obj[0], Arr) else obj[0].detach().cpu().numpy()
+ )
+ self.indices: Arr = (
+ obj[1] if isinstance(obj[1], Arr) else obj[1].detach().cpu().numpy()
+ )
+
def __getitem__(self, item):
return TopK((self.values[item], self.indices[item]))
def concat(self, other: "TopK"):
- '''If self is empty, returns the other (so we can start w/ empty & concatenate consistently).'''
+ """If self is empty, returns the other (so we can start w/ empty & concatenate consistently)."""
if self.numel() == 0:
return other
else:
- return TopK((
- np.concatenate((self.values, other.values)),
- np.concatenate((self.indices, other.indices))
- ))
-
+ return TopK(
+ (
+ np.concatenate((self.values, other.values)),
+ np.concatenate((self.indices, other.indices)),
+ )
+ )
+
@property
def ndim(self):
return self.values.ndim
@@ -158,31 +164,30 @@ def shape(self):
@property
def size(self):
return self.values.size()
-
+
def numel(self):
return self.values.size
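
A short usage sketch of the `TopK` wrapper, assuming torch and numpy are available:

    import numpy as np
    import torch

    topk = TopK(torch.randn(10, 5).topk(k=3, dim=-1))
    print(topk.values.shape, topk.indices.shape)  # (10, 3) (10, 3), as numpy arrays
    first_row = topk[0]                           # indexes values & indices together
    merged = TopK((np.array([]), np.array([]))).concat(topk)  # empty-start concat
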
class Output:
- '''So I can type annotate the output of transformer.'''
+ """So I can type annotate the output of transformer."""
+
loss: Tensor
logits: Tensor
-
def merge_lists(*lists):
return [item for sublist in lists for item in sublist]
-
def extract_and_remove_scripts(html_content) -> Tuple[str, str]:
# Pattern to find <script> tags
- pattern = r'<script[^>]*>.*?</script>'
+ pattern = r"<script[^>]*>.*?</script>"
# Find all script tags
scripts = re.findall(pattern, html_content, re.DOTALL)
# Remove script tags from the original content
- html_without_scripts = re.sub(pattern, '', html_content, flags=re.DOTALL)
+ html_without_scripts = re.sub(pattern, "", html_content, flags=re.DOTALL)
return "\n".join(scripts), html_without_scripts
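
Assuming the `<script[^>]*>.*?</script>` pattern above, a usage sketch:

    html = "<div>hi</div><script>console.log(1);</script>"
    scripts, rest = extract_and_remove_scripts(html)
    # scripts == "<script>console.log(1);</script>"; rest == "<div>hi</div>"
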
diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py
index 9b3e6c7b..0d3722a3 100644
--- a/sae_training/activations_store.py
+++ b/sae_training/activations_store.py
@@ -10,16 +10,20 @@
class ActivationsStore:
"""
Class for streaming tokens and generating and storing activations
- while training SAEs.
+ while training SAEs.
"""
+
def __init__(
- self, cfg, model: HookedTransformer, create_dataloader: bool = True,
+ self,
+ cfg,
+ model: HookedTransformer,
+ create_dataloader: bool = True,
):
self.cfg = cfg
self.model = model
self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
self.iterable_dataset = iter(self.dataset)
-
+
# check if it's tokenized
if "tokens" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = True
@@ -27,15 +31,16 @@ def __init__(
elif "text" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = False
print("Dataset is not tokenized! Updating config.")
-
+
if self.cfg.use_cached_activations:
# Sanity check: does the cache directory exist?
- assert os.path.exists(self.cfg.cached_activations_path), \
- f"Cache directory {self.cfg.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."
-
- self.next_cache_idx = 0 # which file to open next
- self.next_idx_within_buffer = 0 # where to start reading from in that file
-
+ assert os.path.exists(
+ self.cfg.cached_activations_path
+ ), f"Cache directory {self.cfg.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."
+
+ self.next_cache_idx = 0 # which file to open next
+ self.next_idx_within_buffer = 0 # where to start reading from in that file
+
# Check that we have enough data on disk
first_buffer = torch.load(f"{self.cfg.cached_activations_path}/0.pt")
buffer_size_on_disk = first_buffer.shape[0]
@@ -43,11 +48,12 @@ def __init__(
# Note: we're assuming all files have the same number of tokens
# (which seems reasonable imo since that's what our script does)
n_activations_on_disk = buffer_size_on_disk * n_buffers_on_disk
- assert n_activations_on_disk > self.cfg.total_training_tokens, \
- f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but cfg.total_training_tokens is {self.cfg.total_training_tokens/1e6:.1f}M."
-
+ assert (
+ n_activations_on_disk > self.cfg.total_training_tokens
+ ), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but cfg.total_training_tokens is {self.cfg.total_training_tokens/1e6:.1f}M."
+
# TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)
-
+
if create_dataloader:
# fill half a buffer, so we can mix it with a new buffer
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
@@ -62,7 +68,9 @@ def get_batch_tokens(self):
context_size = self.cfg.context_size
device = self.cfg.device
- batch_tokens = torch.zeros(size=(0, context_size), device=device, dtype=torch.long, requires_grad=False)
+ batch_tokens = torch.zeros(
+ size=(0, context_size), device=device, dtype=torch.long, requires_grad=False
+ )
current_batch = []
current_length = 0
@@ -72,11 +80,13 @@ def get_batch_tokens(self):
if not self.cfg.is_dataset_tokenized:
s = next(self.iterable_dataset)["text"]
tokens = self.model.to_tokens(
- s,
- truncate=True,
+ s,
+ truncate=True,
move_to_device=True,
- ).squeeze(0)
- assert len(tokens.shape) == 1, f"tokens.shape should be 1D but was {tokens.shape}"
+ ).squeeze(0)
+ assert (
+ len(tokens.shape) == 1
+ ), f"tokens.shape should be 1D but was {tokens.shape}"
else:
tokens = torch.tensor(
next(self.iterable_dataset)["tokens"],
@@ -87,8 +97,12 @@ def get_batch_tokens(self):
token_len = tokens.shape[0]
# TODO: Fix this so that we are limiting how many tokens we get from the same context.
-
- bos_token_id_tensor = torch.tensor([self.model.tokenizer.bos_token_id], device=tokens.device, dtype=torch.long)
+
+ bos_token_id_tensor = torch.tensor(
+ [self.model.tokenizer.bos_token_id],
+ device=tokens.device,
+ dtype=torch.long,
+ )
while token_len > 0 and batch_tokens.shape[0] < batch_size:
# Space left in the current batch
space_left = context_size - current_length
@@ -131,26 +145,16 @@ def get_batch_tokens(self):
return batch_tokens[:batch_size]
def get_activations(self, batch_tokens, get_loss=False):
-
act_name = self.cfg.hook_point
hook_point_layer = self.cfg.hook_point_layer
if self.cfg.hook_point_head_index is not None:
activations = self.model.run_with_cache(
- batch_tokens,
- names_filter=act_name,
- stop_at_layer=hook_point_layer+1
- )[
- 1
- ][act_name][:,:,self.cfg.hook_point_head_index]
+ batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
+ )[1][act_name][:, :, self.cfg.hook_point_head_index]
else:
activations = self.model.run_with_cache(
- batch_tokens,
- names_filter=act_name,
- stop_at_layer=hook_point_layer+1
- )[
- 1
- ][act_name]
-
+ batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
+ )[1][act_name]
return activations
@@ -164,35 +168,48 @@ def get_buffer(self, n_batches_in_buffer):
# Load the activations from disk
buffer_size = total_size * context_size
# Initialize an empty tensor (flattened along all dims except d_in)
- new_buffer = torch.zeros((buffer_size, d_in), dtype=self.cfg.dtype,
- device=self.cfg.device)
+ new_buffer = torch.zeros(
+ (buffer_size, d_in), dtype=self.cfg.dtype, device=self.cfg.device
+ )
n_tokens_filled = 0
-
+
# The activations may be split across multiple files,
# Or we might only want a subset of one file (depending on the sizes)
while n_tokens_filled < buffer_size:
# Load the next file
# Make sure it exists
- if not os.path.exists(f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"):
- print("\n\nWarning: Ran out of cached activation files earlier than expected.")
- print(f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}.")
+ if not os.path.exists(
+ f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
+ ):
+ print(
+ "\n\nWarning: Ran out of cached activation files earlier than expected."
+ )
+ print(
+ f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}."
+ )
if buffer_size % self.cfg.total_training_tokens != 0:
- print("This might just be a rounding error — your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens")
+ print(
+ "This might just be a rounding error — your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens"
+ )
print(f"Returning a buffer of size {n_tokens_filled} instead.")
print("\n\n")
new_buffer = new_buffer[:n_tokens_filled]
break
- activations = torch.load(f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt")
-
+ activations = torch.load(
+ f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
+ )
+
# If we only want a subset of the file, take it
taking_subset_of_file = False
if n_tokens_filled + activations.shape[0] > buffer_size:
- activations = activations[:buffer_size - n_tokens_filled]
+ activations = activations[: buffer_size - n_tokens_filled]
taking_subset_of_file = True
-
+
# Add it to the buffer
- new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0]] = activations
-
+ new_buffer[
+ n_tokens_filled : n_tokens_filled + activations.shape[0]
+ ] = activations
+
# Update counters
n_tokens_filled += activations.shape[0]
if taking_subset_of_file:
@@ -200,7 +217,7 @@ def get_buffer(self, n_batches_in_buffer):
else:
self.next_cache_idx += 1
self.next_idx_within_buffer = 0
-
+
return new_buffer
refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
@@ -221,7 +238,7 @@ def get_buffer(self, n_batches_in_buffer):
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations
-
+
# pbar.update(1)
new_buffer = new_buffer.reshape(-1, d_in)
@@ -229,37 +246,43 @@ def get_buffer(self, n_batches_in_buffer):
return new_buffer
- def get_data_loader(self,) -> DataLoader:
- '''
+ def get_data_loader(
+ self,
+ ) -> DataLoader:
+ """
Return a torch.utils.dataloader which you can get batches from.
-
- Should automatically refill the buffer when it gets to n % full.
+
+ Should automatically refill the buffer when it gets to n % full.
(better mixing if you refill and shuffle regularly).
-
- '''
-
+
+ """
+
batch_size = self.cfg.train_batch_size
-
+
# 1. create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
- [self.get_buffer(self.cfg.n_batches_in_buffer // 2),
- self.storage_buffer]
+ [self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer]
)
-
+
mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
-
+
# 2. put 50 % in storage
- self.storage_buffer = mixing_buffer[:mixing_buffer.shape[0]//2]
-
+ self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
+
# 3. put other 50 % in a dataloader
- dataloader = iter(DataLoader(mixing_buffer[:mixing_buffer.shape[0]//2:], batch_size=batch_size, shuffle=True))
-
+ dataloader = iter(
+ DataLoader(
+ mixing_buffer[mixing_buffer.shape[0] // 2 :],
+ batch_size=batch_size,
+ shuffle=True,
+ )
+ )
+
return dataloader
-
-
+
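
The mixing scheme above: concatenate a freshly generated half-buffer with the stored half, shuffle, return one half to storage, and serve the other half through a DataLoader. A minimal sketch with toy shapes:

    import torch
    from torch.utils.data import DataLoader

    storage = torch.randn(100, 16)                 # previously stored half-buffer
    fresh = torch.randn(100, 16)                   # newly generated half-buffer

    mixed = torch.cat([fresh, storage])
    mixed = mixed[torch.randperm(mixed.shape[0])]  # shuffle the activations together

    storage = mixed[: mixed.shape[0] // 2]         # half goes back into storage
    loader = iter(
        DataLoader(mixed[mixed.shape[0] // 2 :], batch_size=32, shuffle=True)
    )
    batch = next(loader)                           # shape (32, 16)
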
def next_batch(self):
"""
- Get the next batch from the current DataLoader.
+ Get the next batch from the current DataLoader.
If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
"""
try:
@@ -268,4 +291,4 @@ def next_batch(self):
except StopIteration:
# If the DataLoader is exhausted, create a new one
self.dataloader = self.get_data_loader()
- return next(self.dataloader)
\ No newline at end of file
+ return next(self.dataloader)
diff --git a/sae_training/cache_activations_runner.py b/sae_training/cache_activations_runner.py
index 7056a230..cb4daae6 100644
--- a/sae_training/cache_activations_runner.py
+++ b/sae_training/cache_activations_runner.py
@@ -14,36 +14,46 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
model = HookedTransformer.from_pretrained(cfg.model_name)
model.to(cfg.device)
activations_store = ActivationsStore(cfg, model, create_dataloader=False)
-
+
# if the activations directory exists and has files in it, raise an exception
if os.path.exists(activations_store.cfg.cached_activations_path):
if len(os.listdir(activations_store.cfg.cached_activations_path)) > 0:
- raise Exception(f"Activations directory ({activations_store.cfg.cached_activations_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files.")
+ raise Exception(
+ f"Activations directory ({activations_store.cfg.cached_activations_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
+ )
else:
os.makedirs(activations_store.cfg.cached_activations_path)
-
+
print(f"Started caching {cfg.total_training_tokens} activations")
- tokens_per_buffer = cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer
+ tokens_per_buffer = (
+ cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer
+ )
n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer)
for i in tqdm(range(n_buffers), desc="Caching activations"):
buffer = activations_store.get_buffer(cfg.n_batches_in_buffer)
torch.save(buffer, f"{activations_store.cfg.cached_activations_path}/{i}.pt")
del buffer
-
+
if i % cfg.shuffle_every_n_buffers == 0 and i > 0:
# Shuffle the buffers on disk
-
+
# Do random pairwise shuffling between the last shuffle_every_n_buffers buffers
for _ in range(cfg.n_shuffles_with_last_section):
- shuffle_activations_pairwise(activations_store.cfg.cached_activations_path,
- buffer_idx_range=(i - cfg.shuffle_every_n_buffers, i))
-
+ shuffle_activations_pairwise(
+ activations_store.cfg.cached_activations_path,
+ buffer_idx_range=(i - cfg.shuffle_every_n_buffers, i),
+ )
+
# Do more random pairwise shuffling between all the buffers
for _ in range(cfg.n_shuffles_in_entire_dir):
- shuffle_activations_pairwise(activations_store.cfg.cached_activations_path,
- buffer_idx_range=(0, i))
-
+ shuffle_activations_pairwise(
+ activations_store.cfg.cached_activations_path,
+ buffer_idx_range=(0, i),
+ )
+
# More final shuffling (mostly in case we didn't end on an i divisible by shuffle_every_n_buffers)
for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"):
- shuffle_activations_pairwise(activations_store.cfg.cached_activations_path,
- buffer_idx_range=(0, n_buffers))
+ shuffle_activations_pairwise(
+ activations_store.cfg.cached_activations_path,
+ buffer_idx_range=(0, n_buffers),
+ )
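
`shuffle_activations_pairwise` isn't shown in this diff. A plausible sketch of one pairwise shuffle between two buffer files on disk, with the path layout assumed from the calls above; the real helper presumably samples a random pair from `buffer_idx_range`:

    import torch

    def shuffle_pair(path: str, idx1: int, idx2: int) -> None:
        # Load two buffers, permute their concatenation, and split them back.
        buf1 = torch.load(f"{path}/{idx1}.pt")
        buf2 = torch.load(f"{path}/{idx2}.pt")
        joint = torch.cat([buf1, buf2])
        joint = joint[torch.randperm(joint.shape[0])]
        torch.save(joint[: buf1.shape[0]], f"{path}/{idx1}.pt")
        torch.save(joint[buf1.shape[0] :], f"{path}/{idx2}.pt")
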
diff --git a/sae_training/config.py b/sae_training/config.py
index 4ceb0a1b..9341d57d 100644
--- a/sae_training/config.py
+++ b/sae_training/config.py
@@ -32,7 +32,7 @@ class RunnerConfig(ABC):
# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
- store_batch_size: int = 32,
+ store_batch_size: int = 32
# Misc
device: str = "cpu"
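
The old default `store_batch_size: int = 32,` carried a trailing comma, so the default was actually the one-element tuple `(32,)` despite the `int` annotation, which black then rendered explicitly; the bare int is what downstream arithmetic like `cfg.store_batch_size * cfg.context_size` expects. The pitfall in isolation:

    from dataclasses import dataclass

    @dataclass
    class Cfg:
        bad: int = 32,   # trailing comma: the default is the tuple (32,)
        good: int = 32

    print(Cfg().bad)   # (32,)
    print(Cfg().good)  # 32
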
@@ -134,10 +134,9 @@ def __post_init__(self):
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
)
-
if self.use_ghost_grads:
print("Using Ghost Grads.")
-
+
print(
f"We will reset the sparsity calculation {n_feature_window_samples} times."
)
diff --git a/sae_training/evals.py b/sae_training/evals.py
index 2d22138f..50517e91 100644
--- a/sae_training/evals.py
+++ b/sae_training/evals.py
@@ -10,57 +10,68 @@
@torch.no_grad()
-def run_evals(sparse_autoencoder: SparseAutoencoder, activation_store: ActivationsStore, model: HookedTransformer, n_training_steps: int):
-
+def run_evals(
+ sparse_autoencoder: SparseAutoencoder,
+ activation_store: ActivationsStore,
+ model: HookedTransformer,
+ n_training_steps: int,
+):
hook_point = sparse_autoencoder.cfg.hook_point
hook_point_layer = sparse_autoencoder.cfg.hook_point_layer
hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index
-
- ### Evals
+
+ ### Evals
eval_tokens = activation_store.get_batch_tokens()
-
+
# Get Reconstruction Score
- recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(sparse_autoencoder, model, activation_store, eval_tokens)
-
+ recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(
+ sparse_autoencoder, model, activation_store, eval_tokens
+ )
+
# get cache
- _, cache = model.run_with_cache(eval_tokens, prepend_bos=False, names_filter=[get_act_name("pattern", hook_point_layer), hook_point])
-
+ _, cache = model.run_with_cache(
+ eval_tokens,
+ prepend_bos=False,
+ names_filter=[get_act_name("pattern", hook_point_layer), hook_point],
+ )
+
# get act
if sparse_autoencoder.cfg.hook_point_head_index is not None:
- original_act = cache[sparse_autoencoder.cfg.hook_point][:,:,sparse_autoencoder.cfg.hook_point_head_index]
+ original_act = cache[sparse_autoencoder.cfg.hook_point][
+ :, :, sparse_autoencoder.cfg.hook_point_head_index
+ ]
else:
original_act = cache[sparse_autoencoder.cfg.hook_point]
-
- sae_out, feature_acts, _, _, _, _ = sparse_autoencoder(
- original_act
+
+ sae_out, feature_acts, _, _, _, _ = sparse_autoencoder(original_act)
+ patterns_original = (
+ cache[get_act_name("pattern", hook_point_layer)][:, hook_point_head_index]
+ .detach()
+ .cpu()
+ )
- patterns_original = cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu()
del cache
-
+
if "cuda" in str(model.cfg.device):
torch.cuda.empty_cache()
-
+
l2_norm_in = torch.norm(original_act, dim=-1)
l2_norm_out = torch.norm(sae_out, dim=-1)
l2_norm_ratio = l2_norm_out / l2_norm_in
-
+
wandb.log(
{
-
# l2 norms
"metrics/l2_norm": l2_norm_out.mean().item(),
"metrics/l2_ratio": l2_norm_ratio.mean().item(),
-
# CE Loss
"metrics/CE_loss_score": recons_score,
"metrics/ce_loss_without_sae": ntp_loss,
"metrics/ce_loss_with_sae": recons_loss,
"metrics/ce_loss_with_ablation": zero_abl_loss,
-
},
step=n_training_steps,
)
-
+
head_index = sparse_autoencoder.cfg.hook_point_head_index
def standard_replacement_hook(activations, hook):
@@ -68,44 +79,65 @@ def standard_replacement_hook(activations, hook):
return activations
def head_replacement_hook(activations, hook):
- new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype)
- activations[:,:,head_index] = new_actions
+ new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[0].to(
+ activations.dtype
+ )
+ activations[:, :, head_index] = new_activations
return activations
head_index = sparse_autoencoder.cfg.hook_point_head_index
- replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook
-
+ replacement_hook = (
+ standard_replacement_hook if head_index is None else head_replacement_hook
+ )
+
# get attn when using reconstructed activations
with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]):
- _, new_cache = model.run_with_cache(eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)])
- patterns_reconstructed = new_cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu()
+ _, new_cache = model.run_with_cache(
+ eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]
+ )
+ patterns_reconstructed = (
+ new_cache[get_act_name("pattern", hook_point_layer)][
+ :, hook_point_head_index
+ ]
+ .detach()
+ .cpu()
+ )
del new_cache
-
+
# get attn when using zero-ablated activations
with model.hooks(fwd_hooks=[(hook_point, partial(zero_ablate_hook))]):
- _, zero_ablation_cache = model.run_with_cache(eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)])
- patterns_ablation = zero_ablation_cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu()
+ _, zero_ablation_cache = model.run_with_cache(
+ eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)]
+ )
+ patterns_ablation = (
+ zero_ablation_cache[get_act_name("pattern", hook_point_layer)][
+ :, hook_point_head_index
+ ]
+ .detach()
+ .cpu()
+ )
del zero_ablation_cache
-
+
if sparse_autoencoder.cfg.hook_point_head_index:
-
- kl_result_reconstructed = kl_divergence_attention(patterns_original, patterns_reconstructed)
+ kl_result_reconstructed = kl_divergence_attention(
+ patterns_original, patterns_reconstructed
+ )
kl_result_reconstructed = kl_result_reconstructed.sum(dim=-1).numpy()
-
- kl_result_ablation = kl_divergence_attention(patterns_original, patterns_ablation)
+ kl_result_ablation = kl_divergence_attention(
+ patterns_original, patterns_ablation
+ )
kl_result_ablation = kl_result_ablation.sum(dim=-1).numpy()
wandb.log(
{
-
- "metrics/kldiv_reconstructed": kl_result_reconstructed.mean().item(),
- "metrics/kldiv_ablation": kl_result_ablation.mean().item(),
-
+ "metrics/kldiv_reconstructed": kl_result_reconstructed.mean().item(),
+ "metrics/kldiv_ablation": kl_result_ablation.mean().item(),
},
step=n_training_steps,
)
+
@torch.no_grad()
def get_recons_loss(sparse_autoencoder, model, activation_store, batch_tokens):
hook_point = activation_store.cfg.hook_point
@@ -118,11 +150,15 @@ def standard_replacement_hook(activations, hook):
return activations
def head_replacement_hook(activations, hook):
- new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype)
- activations[:,:,head_index] = new_actions
+ new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[0].to(
+ activations.dtype
+ )
+ activations[:, :, head_index] = new_activations
return activations
- replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook
+ replacement_hook = (
+ standard_replacement_hook if head_index is None else head_replacement_hook
+ )
recons_loss = model.run_with_hooks(
batch_tokens,
return_type="loss",
@@ -149,9 +185,8 @@ def zero_ablate_hook(mlp_post, hook):
def kl_divergence_attention(y_true, y_pred):
-
# Compute log probabilities for KL divergence
log_y_true = torch.log2(y_true + 1e-10)
log_y_pred = torch.log2(y_pred + 1e-10)
- return y_true * (log_y_true - log_y_pred)
\ No newline at end of file
+ return y_true * (log_y_true - log_y_pred)
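
`kl_divergence_attention` returns the elementwise KL terms in bits (base-2 logs); callers sum over the last axis, as `run_evals` does above. A quick numerical check:

    import torch

    y_true = torch.tensor([[0.5, 0.5]])
    y_pred = torch.tensor([[0.9, 0.1]])
    kl = kl_divergence_attention(y_true, y_pred).sum(dim=-1)
    print(kl)  # ~0.737 bits: 0.5*log2(0.5/0.9) + 0.5*log2(0.5/0.1)
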
diff --git a/sae_training/geom_median/setup.py b/sae_training/geom_median/setup.py
index ca52ba1d..d49bb7d1 100644
--- a/sae_training/geom_median/setup.py
+++ b/sae_training/geom_median/setup.py
@@ -24,6 +24,6 @@
packages=setuptools.find_packages(where="src"),
python_requires=">=3.6",
install_requires=[
- 'numpy>=1.18.1',
- ]
+ "numpy>=1.18.1",
+ ],
)
diff --git a/sae_training/geom_median/src/geom_median/__init__.py b/sae_training/geom_median/src/geom_median/__init__.py
index 8b137891..e69de29b 100644
--- a/sae_training/geom_median/src/geom_median/__init__.py
+++ b/sae_training/geom_median/src/geom_median/__init__.py
@@ -1 +0,0 @@
-
diff --git a/sae_training/geom_median/src/geom_median/numpy/__init__.py b/sae_training/geom_median/src/geom_median/numpy/__init__.py
index 8b9fe0e3..6e0a637b 100644
--- a/sae_training/geom_median/src/geom_median/numpy/__init__.py
+++ b/sae_training/geom_median/src/geom_median/numpy/__init__.py
@@ -1,3 +1,3 @@
from .main import compute_geometric_median
-__all__ = [compute_geometric_median]
\ No newline at end of file
+__all__ = ["compute_geometric_median"]
diff --git a/sae_training/geom_median/src/geom_median/numpy/main.py b/sae_training/geom_median/src/geom_median/numpy/main.py
index 96f8f131..a52aef68 100644
--- a/sae_training/geom_median/src/geom_median/numpy/main.py
+++ b/sae_training/geom_median/src/geom_median/numpy/main.py
@@ -4,34 +4,47 @@
from .weiszfeld_list_of_array import geometric_median_list_of_array
from . import utils
+
def compute_geometric_median(
- points, weights=None, per_component=False, skip_typechecks=False,
- eps=1e-6, maxiter=100, ftol=1e-20
+ points,
+ weights=None,
+ per_component=False,
+ skip_typechecks=False,
+ eps=1e-6,
+ maxiter=100,
+ ftol=1e-20,
):
- """ Compute the geometric median of points `points` with weights given by `weights`.
- """
- if weights is None:
- n = len(points)
- weights = np.ones(n)
- if type(points) == np.ndarray:
- # `points` are given as an array of shape (n, d)
- points = [p for p in points] # translate to list of arrays format
- if type(points) not in [list, tuple]:
- raise ValueError(
- f"We expect `points` as a list of arrays or a list of tuples of arrays. Got {type(points)}"
- )
- if type(points[0]) == np.ndarray: # `points` are given in list of arrays format
- if not skip_typechecks:
- utils.check_list_of_array_format(points)
- to_return = geometric_median_array(points, weights, eps, maxiter, ftol)
- elif type(points[0]) in [list, tuple]: # `points` are in list of list of arrays format
- if not skip_typechecks:
- utils.check_list_of_list_of_array_format(points)
- if per_component:
- to_return = geometric_median_per_component(points, weights, eps, maxiter, ftol)
- else:
- to_return = geometric_median_list_of_array(points, weights, eps, maxiter, ftol)
- else:
- raise ValueError(f"Unexpected format {type(points[0])} for list of list format.")
- return to_return
-
\ No newline at end of file
+ """Compute the geometric median of points `points` with weights given by `weights`."""
+ if weights is None:
+ n = len(points)
+ weights = np.ones(n)
+ if type(points) == np.ndarray:
+ # `points` are given as an array of shape (n, d)
+ points = [p for p in points] # translate to list of arrays format
+ if type(points) not in [list, tuple]:
+ raise ValueError(
+ f"We expect `points` as a list of arrays or a list of tuples of arrays. Got {type(points)}"
+ )
+ if type(points[0]) == np.ndarray: # `points` are given in list of arrays format
+ if not skip_typechecks:
+ utils.check_list_of_array_format(points)
+ to_return = geometric_median_array(points, weights, eps, maxiter, ftol)
+ elif type(points[0]) in [
+ list,
+ tuple,
+ ]: # `points` are in list of list of arrays format
+ if not skip_typechecks:
+ utils.check_list_of_list_of_array_format(points)
+ if per_component:
+ to_return = geometric_median_per_component(
+ points, weights, eps, maxiter, ftol
+ )
+ else:
+ to_return = geometric_median_list_of_array(
+ points, weights, eps, maxiter, ftol
+ )
+ else:
+ raise ValueError(
+ f"Unexpected format {type(points[0])} for list of list format."
+ )
+ return to_return
diff --git a/sae_training/geom_median/src/geom_median/numpy/utils.py b/sae_training/geom_median/src/geom_median/numpy/utils.py
index f1d866b5..29382e4a 100644
--- a/sae_training/geom_median/src/geom_median/numpy/utils.py
+++ b/sae_training/geom_median/src/geom_median/numpy/utils.py
@@ -1,35 +1,34 @@
from itertools import zip_longest
import numpy as np
-def check_list_of_array_format(points):
- check_shapes_compatibility(points, -1)
-def check_list_of_list_of_array_format(points):
- # each element of `points` is a list of arrays of compatible shapes
- components = zip_longest(*points, fillvalue=np.array(0))
- for i, component in enumerate(components):
- check_shapes_compatibility(component, i)
+def check_list_of_array_format(points):
+ check_shapes_compatibility(points, -1)
-def check_shapes_compatibility(list_of_arrays, i):
- arr0 = list_of_arrays[0]
- if not isinstance(arr0, np.ndarray):
- raise ValueError(
- "Expected points of format list of `numpy.ndarray`s.",
- f"Got {type(arr0)} for component {i} of point 0."
- )
- shape = arr0.shape
- for j, arr in enumerate(list_of_arrays[1:]):
- if not isinstance(arr, np.ndarray):
- raise ValueError(
- f"Expected points of format list of `numpy.ndarray`s. Got {type(arr)}",
- f"for component {i} of point {j+1}."
- )
- if arr.shape != shape:
- raise ValueError(
- f"Expected shape {shape} for component {i} of point {j+1}.",
- f"Got shape {arr.shape} instead."
- )
-
+def check_list_of_list_of_array_format(points):
+ # each element of `points` is a list of arrays of compatible shapes
+ components = zip_longest(*points, fillvalue=np.array(0))
+ for i, component in enumerate(components):
+ check_shapes_compatibility(component, i)
+def check_shapes_compatibility(list_of_arrays, i):
+ arr0 = list_of_arrays[0]
+ if not isinstance(arr0, np.ndarray):
+ raise ValueError(
+ "Expected points of format list of `numpy.ndarray`s.",
+ f"Got {type(arr0)} for component {i} of point 0.",
+ )
+ shape = arr0.shape
+ for j, arr in enumerate(list_of_arrays[1:]):
+ if not isinstance(arr, np.ndarray):
+ raise ValueError(
+ f"Expected points of format list of `numpy.ndarray`s. Got {type(arr)}",
+ f"for component {i} of point {j+1}.",
+ )
+ if arr.shape != shape:
+ raise ValueError(
+ f"Expected shape {shape} for component {i} of point {j+1}.",
+ f"Got shape {arr.shape} instead.",
+ )
diff --git a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py
index ccddd642..5b352a7c 100644
--- a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py
+++ b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py
@@ -1,12 +1,13 @@
import numpy as np
from types import SimpleNamespace
+
def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
:param points: list of length :math:`n`, whose elements are each a ``numpy.array`` of shape ``(d,)``
:param weights: ``numpy.array`` of shape :math:``(n,)``.
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
- Equivalently, this is a smoothing parameter. Default 1e-6.
+ :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+ Equivalently, this is a smoothing parameter. Default 1e-6.
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
:return: SimpleNamespace object with fields
@@ -35,22 +36,25 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
return SimpleNamespace(
median=median,
- termination="function value converged within tolerance" if early_termination else "maximum iterations reached",
+ termination="function value converged within tolerance"
+ if early_termination
+ else "maximum iterations reached",
logs=logs,
)
+
def geometric_median_per_component(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
:param points: list of length :math:``n``, where each element is itself a list of ``numpy.ndarray``.
Each inner list has the same "shape".
:param weights: ``numpy.ndarray`` of shape :math:``(n,)``.
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
- Equivalently, this is a smoothing parameter. Default 1e-6.
+ :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+ Equivalently, this is a smoothing parameter. Default 1e-6.
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
:return: SimpleNamespace object with fields
- `median`: estimate of the geometric median, which is a list of ``numpy.ndarray`` of the same "shape" as the input.
- - `termination`: string explaining how the algorithm terminated, one for each component.
+ - `termination`: string explaining how the algorithm terminated, one for each component.
- `logs`: function values encountered through the course of the algorithm.
"""
components = list(zip(*points))
@@ -64,6 +68,7 @@ def geometric_median_per_component(points, weights, eps=1e-6, maxiter=100, ftol=
logs.append(ret.logs)
return SimpleNamespace(median=median, termination=termination, logs=logs)
+
def weighted_average(points, weights):
"""
Compute a weighted average of rows of `points`, with each row weighted by the corresponding entry in `weights`
@@ -75,4 +80,6 @@ def weighted_average(points, weights):
def geometric_median_objective(median, points, weights):
- return np.average([np.linalg.norm((p - median).reshape(-1)) for p in points], weights=weights)
\ No newline at end of file
+ return np.average(
+ [np.linalg.norm((p - median).reshape(-1)) for p in points], weights=weights
+ )
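
The Weiszfeld loop above reweights each point by the inverse of its distance to the current estimate, recomputes the weighted average, and stops once the objective improves by less than an `ftol` fraction. A compact numpy sketch of the same iteration for a plain (n, d) array of points:

    import numpy as np

    def weiszfeld(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
        median = np.average(points, weights=weights, axis=0)  # start at weighted mean
        obj = np.average(np.linalg.norm(points - median, axis=1), weights=weights)
        for _ in range(maxiter):
            prev = obj
            dists = np.maximum(eps, np.linalg.norm(points - median, axis=1))
            median = np.average(points, weights=weights / dists, axis=0)
            obj = np.average(np.linalg.norm(points - median, axis=1), weights=weights)
            if abs(prev - obj) <= ftol * obj:
                break
        return median

    pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [10.0, 10.0]])
    print(weiszfeld(pts, np.ones(4)))  # pulled to the cluster, robust to the outlier
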
diff --git a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py
index ab21f99f..172894f0 100644
--- a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py
+++ b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py
@@ -1,13 +1,14 @@
import numpy as np
from types import SimpleNamespace
+
def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
:param points: list of length :math:``n``, where each element is itself a list of ``numpy.ndarray``.
Each inner list has the same "shape".
:param weights: ``numpy.ndarray`` of shape :math:``(n,)``.
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
- Equivalently, this is a smoothing parameter. Default 1e-6.
+ :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+ Equivalently, this is a smoothing parameter. Default 1e-6.
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
:return: SimpleNamespace object with fields
@@ -24,7 +25,9 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=
early_termination = False
for _ in range(maxiter):
prev_obj_value = objective_value
- new_weights = weights / np.maximum(eps, np.asarray([l2distance(p, median) for p in points]))
+ new_weights = weights / np.maximum(
+ eps, np.asarray([l2distance(p, median) for p in points])
+ )
median = weighted_average(points, new_weights)
objective_value = geometric_median_objective(median, points, weights)
@@ -35,19 +38,27 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=
return SimpleNamespace(
median=median,
- termination="function value converged within tolerance" if early_termination else "maximum iterations reached",
+ termination="function value converged within tolerance"
+ if early_termination
+ else "maximum iterations reached",
logs=logs,
)
+
def weighted_average(points, weights):
- return [np.average(component, weights=weights, axis=0) for component in zip(*points)]
+ return [
+ np.average(component, weights=weights, axis=0) for component in zip(*points)
+ ]
+
def geometric_median_objective(median, points, weights):
return np.average([l2distance(p, median) for p in points], weights=weights)
+
# Simple operators for list-of-array format
def l2distance(p1, p2):
return np.linalg.norm([np.linalg.norm(x1 - x2) for (x1, x2) in zip(p1, p2)])
+
def subtract(p1, p2):
- return [x1 - x2 for (x1, x2) in zip(p1, p2)]
\ No newline at end of file
+ return [x1 - x2 for (x1, x2) in zip(p1, p2)]
diff --git a/sae_training/geom_median/src/geom_median/torch/__init__.py b/sae_training/geom_median/src/geom_median/torch/__init__.py
index 8b9fe0e3..6e0a637b 100644
--- a/sae_training/geom_median/src/geom_median/torch/__init__.py
+++ b/sae_training/geom_median/src/geom_median/torch/__init__.py
@@ -1,3 +1,3 @@
from .main import compute_geometric_median
-__all__ = [compute_geometric_median]
\ No newline at end of file
+__all__ = ["compute_geometric_median"]
diff --git a/sae_training/geom_median/src/geom_median/torch/main.py b/sae_training/geom_median/src/geom_median/torch/main.py
index 2fd2b187..86eb4866 100644
--- a/sae_training/geom_median/src/geom_median/torch/main.py
+++ b/sae_training/geom_median/src/geom_median/torch/main.py
@@ -4,35 +4,48 @@
from .weiszfeld_list_of_array import geometric_median_list_of_array
from . import utils
+
def compute_geometric_median(
- points, weights=None, per_component=False, skip_typechecks=False,
- eps=1e-6, maxiter=100, ftol=1e-20
+ points,
+ weights=None,
+ per_component=False,
+ skip_typechecks=False,
+ eps=1e-6,
+ maxiter=100,
+ ftol=1e-20,
):
- """ Compute the geometric median of points `points` with weights given by `weights`.
- """
- if type(points) == torch.Tensor:
- # `points` are given as an array of shape (n, d)
- points = [p for p in points] # translate to list of arrays format
- if type(points) not in [list, tuple]:
- raise ValueError(
- f"We expect `points` as a list of arrays or a list of tuples of arrays. Got {type(points)}"
- )
- if type(points[0]) == torch.Tensor: # `points` are given in list of arrays format
- if not skip_typechecks:
- utils.check_list_of_array_format(points)
- if weights is None:
- weights = torch.ones(len(points), device=points[0].device)
- to_return = geometric_median_array(points, weights, eps, maxiter, ftol)
- elif type(points[0]) in [list, tuple]: # `points` are in list of list of arrays format
- if not skip_typechecks:
- utils.check_list_of_list_of_array_format(points)
- if weights is None:
- weights = torch.ones(len(points), device=points[0][0].device)
- if per_component:
- to_return = geometric_median_per_component(points, weights, eps, maxiter, ftol)
- else:
- to_return = geometric_median_list_of_array(points, weights, eps, maxiter, ftol)
- else:
- raise ValueError(f"Unexpected format {type(points[0])} for list of list format.")
- return to_return
-
+ """Compute the geometric median of points `points` with weights given by `weights`."""
+ if type(points) == torch.Tensor:
+ # `points` are given as an array of shape (n, d)
+ points = [p for p in points] # translate to list of arrays format
+ if type(points) not in [list, tuple]:
+ raise ValueError(
+ f"We expect `points` as a list of arrays or a list of tuples of arrays. Got {type(points)}"
+ )
+ if type(points[0]) == torch.Tensor: # `points` are given in list of arrays format
+ if not skip_typechecks:
+ utils.check_list_of_array_format(points)
+ if weights is None:
+ weights = torch.ones(len(points), device=points[0].device)
+ to_return = geometric_median_array(points, weights, eps, maxiter, ftol)
+ elif type(points[0]) in [
+ list,
+ tuple,
+ ]: # `points` are in list of list of arrays format
+ if not skip_typechecks:
+ utils.check_list_of_list_of_array_format(points)
+ if weights is None:
+ weights = torch.ones(len(points), device=points[0][0].device)
+ if per_component:
+ to_return = geometric_median_per_component(
+ points, weights, eps, maxiter, ftol
+ )
+ else:
+ to_return = geometric_median_list_of_array(
+ points, weights, eps, maxiter, ftol
+ )
+ else:
+ raise ValueError(
+ f"Unexpected format {type(points[0])} for list of list format."
+ )
+ return to_return
diff --git a/sae_training/geom_median/src/geom_median/torch/utils.py b/sae_training/geom_median/src/geom_median/torch/utils.py
index bd1c449c..02b69741 100644
--- a/sae_training/geom_median/src/geom_median/torch/utils.py
+++ b/sae_training/geom_median/src/geom_median/torch/utils.py
@@ -1,35 +1,34 @@
from itertools import zip_longest
import torch
-def check_list_of_array_format(points):
- check_shapes_compatibility(points, -1)
-def check_list_of_list_of_array_format(points):
- # each element of `points` is a list of arrays of compatible shapes
- components = zip_longest(*points, fillvalue=torch.Tensor())
- for i, component in enumerate(components):
- check_shapes_compatibility(component, i)
+def check_list_of_array_format(points):
+ check_shapes_compatibility(points, -1)
-def check_shapes_compatibility(list_of_arrays, i):
- arr0 = list_of_arrays[0]
- if not isinstance(arr0, torch.Tensor):
- raise ValueError(
- "Expected points of format list of `torch.Tensor`s.",
- f"Got {type(arr0)} for component {i} of point 0."
- )
- shape = arr0.shape
- for j, arr in enumerate(list_of_arrays[1:]):
- if not isinstance(arr, torch.Tensor):
- raise ValueError(
- f"Expected points of format list of `torch.Tensor`s. Got {type(arr)}",
- f"for component {i} of point {j+1}."
- )
- if arr.shape != shape:
- raise ValueError(
- f"Expected shape {shape} for component {i} of point {j+1}.",
- f"Got shape {arr.shape} instead."
- )
-
+def check_list_of_list_of_array_format(points):
+ # each element of `points` is a list of arrays of compatible shapes
+ components = zip_longest(*points, fillvalue=torch.Tensor())
+ for i, component in enumerate(components):
+ check_shapes_compatibility(component, i)
+def check_shapes_compatibility(list_of_arrays, i):
+ arr0 = list_of_arrays[0]
+ if not isinstance(arr0, torch.Tensor):
+ raise ValueError(
+ "Expected points of format list of `torch.Tensor`s.",
+ f"Got {type(arr0)} for component {i} of point 0.",
+ )
+ shape = arr0.shape
+ for j, arr in enumerate(list_of_arrays[1:]):
+ if not isinstance(arr, torch.Tensor):
+ raise ValueError(
+ f"Expected points of format list of `torch.Tensor`s. Got {type(arr)}",
+ f"for component {i} of point {j+1}.",
+ )
+ if arr.shape != shape:
+ raise ValueError(
+ f"Expected shape {shape} for component {i} of point {j+1}.",
+ f"Got shape {arr.shape} instead.",
+ )
diff --git a/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py b/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py
index fd8b680c..ae337a02 100644
--- a/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py
+++ b/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py
@@ -9,8 +9,8 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
:param points: list of length :math:`n`, whose elements are each a ``torch.Tensor`` of shape ``(d,)``
:param weights: ``torch.Tensor`` of shape :math:``(n,)``.
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
- Equivalently, this is a smoothing parameter. Default 1e-6.
+ :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+ Equivalently, this is a smoothing parameter. Default 1e-6.
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
:return: SimpleNamespace object with fields
@@ -30,7 +30,9 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
pbar = tqdm.tqdm(range(maxiter))
for _ in pbar:
prev_obj_value = objective_value
- norms = torch.stack([torch.linalg.norm((p - median).view(-1)) for p in points])
+ norms = torch.stack(
+ [torch.linalg.norm((p - median).view(-1)) for p in points]
+ )
new_weights = weights / torch.clamp(norms, min=eps)
median = weighted_average(points, new_weights)
@@ -39,29 +41,32 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
if abs(prev_obj_value - objective_value) <= ftol * objective_value:
early_termination = True
break
-
+
pbar.set_description(f"Objective value: {objective_value:.4f}")
median = weighted_average(points, new_weights) # allow autodiff to track it
return SimpleNamespace(
median=median,
new_weights=new_weights,
- termination="function value converged within tolerance" if early_termination else "maximum iterations reached",
+ termination="function value converged within tolerance"
+ if early_termination
+ else "maximum iterations reached",
logs=logs,
)
+
def geometric_median_per_component(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
:param points: list of length :math:``n``, where each element is itself a list of ``numpy.ndarray``.
Each inner list has the same "shape".
:param weights: ``numpy.ndarray`` of shape :math:``(n,)``.
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
- Equivalently, this is a smoothing parameter. Default 1e-6.
+ :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+ Equivalently, this is a smoothing parameter. Default 1e-6.
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
:return: SimpleNamespace object with fields
- `median`: estimate of the geometric median, which is a list of ``numpy.ndarray`` of the same "shape" as the input.
- - `termination`: string explaining how the algorithm terminated, one for each component.
+ - `termination`: string explaining how the algorithm terminated, one for each component.
- `logs`: function values encountered through the course of the algorithm.
"""
components = list(zip(*points))
@@ -78,6 +83,7 @@ def geometric_median_per_component(points, weights, eps=1e-6, maxiter=100, ftol=
logs.append(ret.logs)
return SimpleNamespace(median=median, termination=termination, logs=logs)
+
def weighted_average(points, weights):
weights = weights / weights.sum()
ret = points[0] * weights[0]
@@ -85,6 +91,10 @@ def weighted_average(points, weights):
ret += points[i] * weights[i]
return ret
+
@torch.no_grad()
def geometric_median_objective(median, points, weights):
- return np.average([torch.linalg.norm((p - median).reshape(-1)).item() for p in points], weights=weights.cpu())
+ return np.average(
+ [torch.linalg.norm((p - median).reshape(-1)).item() for p in points],
+ weights=weights.cpu(),
+ )
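
As a quick sanity check on the Weiszfeld loop above (a sketch assuming `geometric_median_array` from this file is in scope; note it draws a tqdm progress bar):

import torch

points = [torch.randn(3) for _ in range(5)]  # n = 5 points in R^3
weights = torch.ones(5)                      # uniform weights
result = geometric_median_array(points, weights, maxiter=50)
# result.median is a (3,) tensor minimizing the weighted sum of distances;
# result.termination reports whether ftol convergence or maxiter ended the loop
print(result.termination, result.median)
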
diff --git a/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py b/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py
index bc8ae4e3..2920a480 100644
--- a/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py
+++ b/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py
@@ -2,13 +2,14 @@
import torch
from types import SimpleNamespace
+
def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
:param points: list of length :math:``n``, where each element is itself a list of ``torch.Tensor``.
Each inner list has the same "shape".
:param weights: ``torch.Tensor`` of shape :math:``(n,)``.
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
- Equivalently, this is a smoothing parameter. Default 1e-6.
+ :param eps: Smallest allowed value of denominator, to avoid divide by zero.
+ Equivalently, this is a smoothing parameter. Default 1e-6.
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
:return: SimpleNamespace object with fields
@@ -28,7 +29,7 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=
for _ in range(maxiter):
prev_obj_value = objective_value
denom = torch.stack([l2distance(p, median) for p in points])
- new_weights = weights / torch.clamp(denom, min=eps)
+ new_weights = weights / torch.clamp(denom, min=eps)
median = weighted_average(points, new_weights)
objective_value = geometric_median_objective(median, points, weights)
@@ -36,30 +37,43 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=
if abs(prev_obj_value - objective_value) <= ftol * objective_value:
early_termination = True
break
-
+
median = weighted_average(points, new_weights) # for autodiff
return SimpleNamespace(
median=median,
new_weights=new_weights,
- termination="function value converged within tolerance" if early_termination else "maximum iterations reached",
+ termination="function value converged within tolerance"
+ if early_termination
+ else "maximum iterations reached",
logs=logs,
)
+
def weighted_average_component(points, weights):
ret = points[0] * weights[0]
for i in range(1, len(points)):
ret += points[i] * weights[i]
return ret
+
def weighted_average(points, weights):
weights = weights / weights.sum()
- return [weighted_average_component(component, weights=weights) for component in zip(*points)]
+ return [
+ weighted_average_component(component, weights=weights)
+ for component in zip(*points)
+ ]
+
@torch.no_grad()
def geometric_median_objective(median, points, weights):
- return np.average([l2distance(p, median).item() for p in points], weights=weights.cpu())
+ return np.average(
+ [l2distance(p, median).item() for p in points], weights=weights.cpu()
+ )
+
@torch.no_grad()
def l2distance(p1, p2):
- return torch.linalg.norm(torch.stack([torch.linalg.norm(x1 - x2) for (x1, x2) in zip(p1, p2)]))
+ return torch.linalg.norm(
+ torch.stack([torch.linalg.norm(x1 - x2) for (x1, x2) in zip(p1, p2)])
+ )
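
`l2distance` above treats a list of tensors as one flat vector: it takes the norm of the per-component norms. A worked example, assuming the function is in scope:

import torch

p1 = [torch.ones(2, 2), torch.zeros(3)]
p2 = [torch.zeros(2, 2), torch.zeros(3)]
# per-component norms are 2.0 and 0.0, so the overall distance is
# sqrt(2.0**2 + 0.0**2) = 2.0
print(l2distance(p1, p2))  # tensor(2.)
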
diff --git a/sae_training/lm_runner.py b/sae_training/lm_runner.py
index d1609bb6..eb3b482f 100644
--- a/sae_training/lm_runner.py
+++ b/sae_training/lm_runner.py
@@ -10,13 +10,16 @@
def language_model_sae_runner(cfg):
- """
-
- """
-
+ """ """
+
if cfg.from_pretrained_path is not None:
- model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
- cfg.from_pretrained_path)
+ (
+ model,
+ sparse_autoencoder,
+ activations_loader,
+ ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
+ cfg.from_pretrained_path
+ )
cfg = sparse_autoencoder.cfg
else:
loader = LMSparseAutoencoderSessionloader(cfg)
@@ -24,32 +27,35 @@ def language_model_sae_runner(cfg):
if cfg.log_to_wandb:
wandb.init(project=cfg.wandb_project, config=cfg, name=cfg.run_name)
-
+
# train SAE
sparse_autoencoder = train_sae_on_language_model(
- model, sparse_autoencoder, activations_loader,
+ model,
+ sparse_autoencoder,
+ activations_loader,
n_checkpoints=cfg.n_checkpoints,
- batch_size = cfg.train_batch_size,
- feature_sampling_window = cfg.feature_sampling_window,
- dead_feature_threshold = cfg.dead_feature_threshold,
- use_wandb = cfg.log_to_wandb,
- wandb_log_frequency = cfg.wandb_log_frequency
+ batch_size=cfg.train_batch_size,
+ feature_sampling_window=cfg.feature_sampling_window,
+ dead_feature_threshold=cfg.dead_feature_threshold,
+ use_wandb=cfg.log_to_wandb,
+ wandb_log_frequency=cfg.wandb_log_frequency,
)
# save sae to checkpoints folder
path = f"{cfg.checkpoint_path}/final_{sparse_autoencoder.get_name()}.pt"
sparse_autoencoder.save_model(path)
-
+
# upload to wandb
if cfg.log_to_wandb:
model_artifact = wandb.Artifact(
- f"{sparse_autoencoder.get_name()}", type="model", metadata=dict(cfg.__dict__)
+ f"{sparse_autoencoder.get_name()}",
+ type="model",
+ metadata=dict(cfg.__dict__),
)
model_artifact.add_file(path)
wandb.log_artifact(model_artifact, aliases=["final_model"])
-
if cfg.log_to_wandb:
wandb.finish()
-
- return sparse_autoencoder
\ No newline at end of file
+
+ return sparse_autoencoder
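
For reference, a hedged sketch of invoking this runner. The field names below are taken from the `cfg` attributes read in the function above; treating them all as constructor arguments, and the config's module path, are assumptions:

from sae_training.config import LanguageModelSAERunnerConfig  # assumed location
from sae_training.lm_runner import language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(
    train_batch_size=4096,
    n_checkpoints=5,
    feature_sampling_window=1000,
    dead_feature_threshold=1e-8,
    log_to_wandb=False,
)
sae = language_model_sae_runner(cfg)  # returns the trained SparseAutoencoder
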
diff --git a/sae_training/optim.py b/sae_training/optim.py
index f3b00424..27ea8084 100644
--- a/sae_training/optim.py
+++ b/sae_training/optim.py
@@ -1,6 +1,6 @@
-'''
+"""
Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
-'''
+"""
import math
from typing import Optional
@@ -12,9 +12,7 @@
# Linear Warmup and decay
# Cosine Annealing with Warmup
# Cosine Annealing with Warmup / Restarts
-def get_scheduler(
- scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs
-):
+def get_scheduler(scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs):
"""
    Loosely based on this; it seemed simpler to write this than to import
transformers: https://huggingface.co/docs/transformers/main_classes/optimizer_schedules
@@ -31,9 +29,7 @@ def lr_lambda(steps):
if steps < warm_up_steps:
return (steps + 1) / warm_up_steps
else:
- return (training_steps - steps) / (
- training_steps - warm_up_steps
- )
+ return (training_steps - steps) / (training_steps - warm_up_steps)
return lr_lambda
@@ -43,15 +39,11 @@ def lr_lambda(steps):
if steps < warm_up_steps:
return (steps + 1) / warm_up_steps
else:
- progress = (steps - warm_up_steps) / (
- training_steps - warm_up_steps
- )
- return lr_end + 0.5 * (1 - lr_end) * (
- 1 + math.cos(math.pi * progress)
- )
+ progress = (steps - warm_up_steps) / (training_steps - warm_up_steps)
+ return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress))
return lr_lambda
-
+
if scheduler_name is None or scheduler_name.lower() == "constant":
return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0)
elif scheduler_name.lower() == "constantwithwarmup":
@@ -75,9 +67,7 @@ def lr_lambda(steps):
warm_up_steps = kwargs.get("warm_up_steps", 0)
training_steps = kwargs.get("training_steps")
eta_min = kwargs.get("lr_end", 0)
- lr_lambda = get_warmup_cosine_lambda(
- warm_up_steps, training_steps, eta_min
- )
+ lr_lambda = get_warmup_cosine_lambda(warm_up_steps, training_steps, eta_min)
return lr_scheduler.LambdaLR(optimizer, lr_lambda)
elif scheduler_name.lower() == "cosineannealingwarmrestarts":
training_steps = kwargs.get("training_steps")
@@ -88,4 +78,4 @@ def lr_lambda(steps):
optimizer, T_0=T_0, eta_min=eta_min
)
else:
- raise ValueError(f"Unsupported scheduler: {scheduler_name}")
\ No newline at end of file
+ raise ValueError(f"Unsupported scheduler: {scheduler_name}")
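
A short sketch of the scheduler factory above, assuming `get_scheduler` is in scope and that the warmup-cosine branch is named "cosineannealingwarmup". Note that because these are `LambdaLR` multipliers, `lr_end` acts as a fraction of the optimizer's base lr, not an absolute rate:

import torch
from torch import nn, optim

opt = optim.Adam([nn.Parameter(torch.zeros(1))], lr=3e-4)
sched = get_scheduler(
    "cosineannealingwarmup", opt, warm_up_steps=100, training_steps=1000, lr_end=0.1
)
for _ in range(1000):
    opt.step()
    sched.step()
# lr ramps linearly to 3e-4 over the first 100 steps, then follows a
# cosine down to ~3e-5 (0.1 * 3e-4) by step 1000
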
diff --git a/sae_training/sparse_autoencoder.py b/sae_training/sparse_autoencoder.py
index 1096e595..b8e2c9de 100644
--- a/sae_training/sparse_autoencoder.py
+++ b/sae_training/sparse_autoencoder.py
@@ -1,4 +1,3 @@
-
"""Most of this is just copied over from Arthur's code and slightly simplified:
https://github.com/ArthurConmy/sae/blob/main/sae/model.py
"""
@@ -21,9 +20,8 @@
class SparseAutoencoder(HookedRootModule):
- """
-
- """
+ """ """
+
def __init__(
self,
cfg,
@@ -44,7 +42,7 @@ def __init__(
self.W_enc = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(self.d_in, self.d_sae, dtype=self.dtype, device=self.device)
- )
+ )
)
self.b_enc = nn.Parameter(
torch.zeros(self.d_sae, dtype=self.dtype, device=self.device)
@@ -71,7 +69,7 @@ def __init__(
self.setup() # Required for `HookedRootModule`s
- def forward(self, x, dead_neuron_mask = None):
+ def forward(self, x, dead_neuron_mask=None):
# move x to correct dtype
x = x.to(self.dtype)
sae_in = self.hook_sae_in(
@@ -96,43 +94,44 @@ def forward(self, x, dead_neuron_mask = None):
)
+ self.b_dec
)
-
+
# add config for whether l2 is normalized:
x_centred = x - x.mean(dim=0, keepdim=True)
- mse_loss = (torch.pow((sae_out-x.float()), 2) / (x_centred**2).sum(dim=-1, keepdim=True).sqrt())
-
-
+ mse_loss = (
+ torch.pow((sae_out - x.float()), 2)
+ / (x_centred**2).sum(dim=-1, keepdim=True).sqrt()
+ )
mse_loss_ghost_resid = torch.tensor(0.0, dtype=self.dtype, device=self.device)
# gate on config and training so evals is not slowed down.
if self.cfg.use_ghost_grads and self.training and dead_neuron_mask.sum() > 0:
- assert dead_neuron_mask is not None
-
+ assert dead_neuron_mask is not None
+
# ghost protocol
-
+
# 1.
residual = x - sae_out
l2_norm_residual = torch.norm(residual, dim=-1)
-
+
# 2.
feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
- ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask,:]
- l2_norm_ghost_out = torch.norm(ghost_out, dim = -1)
- norm_scaling_factor = l2_norm_residual / (1e-6+ l2_norm_ghost_out* 2)
- ghost_out = ghost_out*norm_scaling_factor[:, None].detach()
-
- # 3.
+ ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
+ l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
+ norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
+ ghost_out = ghost_out * norm_scaling_factor[:, None].detach()
+
+ # 3.
mse_loss_ghost_resid = (
- torch.pow((ghost_out - residual.detach().float()), 2) / (residual.detach()**2).sum(dim=-1, keepdim=True).sqrt()
+ torch.pow((ghost_out - residual.detach().float()), 2)
+ / (residual.detach() ** 2).sum(dim=-1, keepdim=True).sqrt()
)
mse_rescaling_factor = (mse_loss / (mse_loss_ghost_resid + 1e-6)).detach()
mse_loss_ghost_resid = mse_rescaling_factor * mse_loss_ghost_resid
mse_loss_ghost_resid = mse_loss_ghost_resid.mean()
-
mse_loss = mse_loss.mean()
- sparsity = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
+ sparsity = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
l1_loss = self.l1_coefficient * sparsity
loss = mse_loss + l1_loss + mse_loss_ghost_resid
@@ -140,7 +139,6 @@ def forward(self, x, dead_neuron_mask = None):
@torch.no_grad()
def initialize_b_dec(self, activation_store):
-
if self.cfg.b_dec_init_method == "geometric_median":
self.initialize_b_dec_with_geometric_median(activation_store)
elif self.cfg.b_dec_init_method == "mean":
@@ -148,42 +146,45 @@ def initialize_b_dec(self, activation_store):
elif self.cfg.b_dec_init_method == "zeros":
pass
else:
- raise ValueError(f"Unexpected b_dec_init_method: {self.cfg.b_dec_init_method}")
+ raise ValueError(
+ f"Unexpected b_dec_init_method: {self.cfg.b_dec_init_method}"
+ )
@torch.no_grad()
def initialize_b_dec_with_geometric_median(self, activation_store):
-
previous_b_dec = self.b_dec.clone().cpu()
all_activations = activation_store.storage_buffer.detach().cpu()
out = compute_geometric_median(
- all_activations,
- skip_typechecks=True,
- maxiter=100, per_component=False).median
-
+ all_activations, skip_typechecks=True, maxiter=100, per_component=False
+ ).median
+
previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
distances = torch.norm(all_activations - out, dim=-1)
-
+
print("Reinitializing b_dec with geometric median of activations")
- print(f"Previous distances: {previous_distances.median(0).values.mean().item()}")
+ print(
+ f"Previous distances: {previous_distances.median(0).values.mean().item()}"
+ )
print(f"New distances: {distances.median(0).values.mean().item()}")
-
+
out = torch.tensor(out, dtype=self.dtype, device=self.device)
self.b_dec.data = out
-
+
@torch.no_grad()
def initialize_b_dec_with_mean(self, activation_store):
-
previous_b_dec = self.b_dec.clone().cpu()
all_activations = activation_store.storage_buffer.detach().cpu()
out = all_activations.mean(dim=0)
-
+
previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
distances = torch.norm(all_activations - out, dim=-1)
-
+
print("Reinitializing b_dec with mean of activations")
- print(f"Previous distances: {previous_distances.median(0).values.mean().item()}")
+ print(
+ f"Previous distances: {previous_distances.median(0).values.mean().item()}"
+ )
print(f"New distances: {distances.median(0).values.mean().item()}")
-
+
self.b_dec.data = out.to(self.dtype).to(self.device)
@torch.no_grad()
@@ -193,80 +194,82 @@ def get_test_loss(self, batch_tokens, model):
returns per token loss when activations are substituted in.
"""
head_index = self.cfg.hook_point_head_index
-
+
def standard_replacement_hook(activations, hook):
activations = self.forward(activations)[0].to(activations.dtype)
return activations
-
+
def head_replacement_hook(activations, hook):
- new_actions = self.forward(activations[:,:,head_index])[0].to(activations.dtype)
- activations[:,:,head_index] = new_actions
+ new_actions = self.forward(activations[:, :, head_index])[0].to(
+ activations.dtype
+ )
+ activations[:, :, head_index] = new_actions
return activations
- replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook
-
+ replacement_hook = (
+ standard_replacement_hook if head_index is None else head_replacement_hook
+ )
+
ce_loss_with_recons = model.run_with_hooks(
batch_tokens,
return_type="loss",
fwd_hooks=[(self.cfg.hook_point, replacement_hook)],
)
-
+
return ce_loss_with_recons
@torch.no_grad()
def set_decoder_norm_to_unit_norm(self):
self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
+
@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self):
- '''
+ """
Update grads so that they remove the parallel component
(d_sae, d_in) shape
- '''
-
+ """
+
parallel_component = einops.einsum(
self.W_dec.grad,
self.W_dec.data,
"d_sae d_in, d_sae d_in -> d_sae",
)
-
+
self.W_dec.grad -= einops.einsum(
parallel_component,
self.W_dec.data,
"d_sae, d_sae d_in -> d_sae d_in",
)
-
+
def save_model(self, path: str):
- '''
+ """
Basic save function for the model. Saves the model's state_dict and the config used to train it.
- '''
-
+ """
+
# check if path exists
folder = os.path.dirname(path)
os.makedirs(folder, exist_ok=True)
-
- state_dict = {
- "cfg": self.cfg,
- "state_dict": self.state_dict()
- }
-
+
+ state_dict = {"cfg": self.cfg, "state_dict": self.state_dict()}
+
if path.endswith(".pt"):
torch.save(state_dict, path)
elif path.endswith("pkl.gz"):
with gzip.open(path, "wb") as f:
pickle.dump(state_dict, f)
else:
- raise ValueError(f"Unexpected file extension: {path}, supported extensions are .pt and .pkl.gz")
-
-
+ raise ValueError(
+ f"Unexpected file extension: {path}, supported extensions are .pt and .pkl.gz"
+ )
+
print(f"Saved model to {path}")
-
+
@classmethod
def load_from_pretrained(cls, path: str):
- '''
+ """
Load function for the model. Loads the model's state_dict and the config used to train it.
This method can be called directly on the class, without needing an instance.
- '''
+ """
# Ensure the file exists
if not os.path.isfile(path):
@@ -282,25 +285,31 @@ def load_from_pretrained(cls, path: str):
state_dict = torch.load(path)
except Exception as e:
raise IOError(f"Error loading the state dictionary from .pt file: {e}")
-
+
elif path.endswith(".pkl.gz"):
try:
- with gzip.open(path, 'rb') as f:
+ with gzip.open(path, "rb") as f:
state_dict = pickle.load(f)
except Exception as e:
- raise IOError(f"Error loading the state dictionary from .pkl.gz file: {e}")
+ raise IOError(
+ f"Error loading the state dictionary from .pkl.gz file: {e}"
+ )
elif path.endswith(".pkl"):
try:
- with open(path, 'rb') as f:
+ with open(path, "rb") as f:
state_dict = pickle.load(f)
except Exception as e:
raise IOError(f"Error loading the state dictionary from .pkl file: {e}")
else:
- raise ValueError(f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz")
+ raise ValueError(
+ f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
+ )
# Ensure the loaded state contains both 'cfg' and 'state_dict'
- if 'cfg' not in state_dict or 'state_dict' not in state_dict:
- raise ValueError("The loaded state dictionary must contain 'cfg' and 'state_dict' keys")
+ if "cfg" not in state_dict or "state_dict" not in state_dict:
+ raise ValueError(
+ "The loaded state dictionary must contain 'cfg' and 'state_dict' keys"
+ )
# Create an instance of the class using the loaded configuration
instance = cls(cfg=state_dict["cfg"])
@@ -310,4 +319,4 @@ def load_from_pretrained(cls, path: str):
def get_name(self):
sae_name = f"sparse_autoencoder_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}"
- return sae_name
\ No newline at end of file
+ return sae_name
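
Written out flat, the reconstruction math in `forward` above reduces to an encode/decode pair plus the centred-norm-scaled MSE. A sketch with hypothetical shapes, ignoring hooks and ghost grads (the b_dec subtraction before encoding is assumed from how this class is set up):

import torch
import torch.nn.functional as F

d_in, d_sae, batch = 4, 8, 32
W_enc, b_enc = torch.randn(d_in, d_sae), torch.zeros(d_sae)
W_dec, b_dec = torch.randn(d_sae, d_in), torch.zeros(d_in)

x = torch.randn(batch, d_in)
feature_acts = F.relu((x - b_dec) @ W_enc + b_enc)
sae_out = feature_acts @ W_dec + b_dec

x_centred = x - x.mean(dim=0, keepdim=True)
mse_loss = (
    (sae_out - x) ** 2 / (x_centred**2).sum(dim=-1, keepdim=True).sqrt()
).mean()
l1_loss = feature_acts.abs().sum(dim=1).mean()  # scaled by l1_coefficient in the class
loss = mse_loss + l1_loss
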
diff --git a/sae_training/toy_model_runner.py b/sae_training/toy_model_runner.py
index f28fd368..2ac4a36c 100644
--- a/sae_training/toy_model_runner.py
+++ b/sae_training/toy_model_runner.py
@@ -13,7 +13,6 @@
@dataclass
class SAEToyModelRunnerConfig:
-
# ReLu Model Parameters
n_features: int = 5
n_hidden: int = 2
@@ -21,31 +20,33 @@ class SAEToyModelRunnerConfig:
n_anticorrelated_pairs: int = 0
feature_probability: float = 0.025
model_training_steps: int = 10_000
-
+
# SAE Parameters
d_sae: int = 5
-
+
# Training Parameters
l1_coefficient: float = 1e-3
lr: float = 3e-4
- train_batch_size: int = 1024
+ train_batch_size: int = 1024
b_dec_init_method: str = "geometric_median"
-
+
# Sparsity / Dead Feature Handling
- use_ghost_grads: bool = False # not currently implemented, but SAE class expects it.
+ use_ghost_grads: bool = (
+ False # not currently implemented, but SAE class expects it.
+ )
feature_sampling_window: int = 100
- dead_feature_window: int = 100 # unless this window is larger feature sampling,
+    dead_feature_window: int = 100  # unless this window is larger than the feature sampling window
dead_feature_threshold: float = 1e-8
-
+
# Activation Store Parameters
- total_training_tokens: int = 25_000
-
+ total_training_tokens: int = 25_000
+
# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training_toy_model"
wandb_entity: str = None
wandb_log_frequency: int = 50
-
+
# Misc
device: str = "cpu"
seed: int = 42
@@ -55,10 +56,11 @@ class SAEToyModelRunnerConfig:
def __post_init__(self):
self.d_in = self.n_hidden # hidden for the ReLu model is the input for the SAE
+
def toy_model_sae_runner(cfg):
- '''
+ """
A runner for training an SAE on a toy model.
- '''
+ """
# Toy Model Config
toy_model_cfg = ToyConfig(
n_instances=1, # Not set up to train > 1 SAE so shouldn't do > 1 model.
@@ -86,7 +88,9 @@ def toy_model_sae_runner(cfg):
"batch_size instances features, instances hidden features -> batch_size instances hidden",
)
- sparse_autoencoder = SparseAutoencoder(cfg) # config has the hyperparameters for the SAE
+ sparse_autoencoder = SparseAutoencoder(
+ cfg
+ ) # config has the hyperparameters for the SAE
if cfg.log_to_wandb:
wandb.init(project=cfg.wandb_project, config=cfg)
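
A sketch of driving this runner end to end, using only fields visible in the dataclass above:

cfg = SAEToyModelRunnerConfig(
    n_features=5,
    n_hidden=2,
    d_sae=5,
    l1_coefficient=1e-3,
    train_batch_size=1024,
    total_training_tokens=25_000,
    log_to_wandb=False,
)
sparse_autoencoder = toy_model_sae_runner(cfg)  # assumed to return the trained SAE
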
diff --git a/sae_training/toy_models.py b/sae_training/toy_models.py
index ed540ad8..98f38439 100644
--- a/sae_training/toy_models.py
+++ b/sae_training/toy_models.py
@@ -1,9 +1,9 @@
-'''
+"""
https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab?fbclid=IwAR04OCGu_unvxezvDWkys9_6MJPEnXuu6GSqU6ScrMkAb1bvdSYFOeS35AY
https://github.com/callummcdougall/sae-exercises-mats?fbclid=IwAR3qYAELbyD_x5IAYN4yCDFQzxXHeuH6CwMi_E7g4Qg6G1QXRNAYabQ4xGs
-'''
+"""
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union
@@ -24,15 +24,19 @@
device = "cpu"
+
def linear_lr(step, steps):
- return (1 - (step / steps))
+ return 1 - (step / steps)
+
def constant_lr(*_):
return 1.0
+
def cosine_decay_lr(step, steps):
return np.cos(0.5 * np.pi * step / (steps - 1))
+
@dataclass
class Config:
# We optimize n_instances models in a single training loop to let us sweep over
@@ -44,9 +48,10 @@ class Config:
n_correlated_pairs: int = 0
n_anticorrelated_pairs: int = 0
+
class Model(nn.Module):
- W: Float[Tensor, "n_instances n_hidden n_features"]
- b_final: Float[Tensor, "n_instances n_features"]
+ W: Float[Tensor, "n_instances n_hidden n_features"] # noqa
+ b_final: Float[Tensor, "n_instances n_features"] # noqa
# Our linear map is x -> ReLU(W.T @ W @ x + b_final)
def __init__(
@@ -54,38 +59,49 @@ def __init__(
cfg: Config,
feature_probability: Optional[Union[float, Tensor]] = None,
importance: Optional[Union[float, Tensor]] = None,
- device = device,
+ device=device,
):
super().__init__()
self.cfg = cfg
- if feature_probability is None: feature_probability = t.ones(())
- if isinstance(feature_probability, float): feature_probability = t.tensor(feature_probability)
- self.feature_probability = feature_probability.to(device).broadcast_to((cfg.n_instances, cfg.n_features))
- if importance is None: importance = t.ones(())
- if isinstance(importance, float): importance = t.tensor(importance)
- self.importance = importance.to(device).broadcast_to((cfg.n_instances, cfg.n_features))
+ if feature_probability is None:
+ feature_probability = t.ones(())
+ if isinstance(feature_probability, float):
+ feature_probability = t.tensor(feature_probability)
+ self.feature_probability = feature_probability.to(device).broadcast_to(
+ (cfg.n_instances, cfg.n_features)
+ )
+ if importance is None:
+ importance = t.ones(())
+ if isinstance(importance, float):
+ importance = t.tensor(importance)
+ self.importance = importance.to(device).broadcast_to(
+ (cfg.n_instances, cfg.n_features)
+ )
- self.W = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_instances, cfg.n_hidden, cfg.n_features))))
+ self.W = nn.Parameter(
+ nn.init.xavier_normal_(
+ t.empty((cfg.n_instances, cfg.n_hidden, cfg.n_features))
+ )
+ )
self.b_final = nn.Parameter(t.zeros((cfg.n_instances, cfg.n_features)))
self.to(device)
-
def forward(
- self,
- features: Float[Tensor, "... instances features"]
- ) -> Float[Tensor, "... instances features"]:
+ self, features: Float[Tensor, "... instances features"] # noqa
+ ) -> Float[Tensor, "... instances features"]: # noqa
hidden = einops.einsum(
- features, self.W,
- "... instances features, instances hidden features -> ... instances hidden"
+ features,
+ self.W,
+ "... instances features, instances hidden features -> ... instances hidden",
)
out = einops.einsum(
- hidden, self.W,
- "... instances hidden, instances hidden features -> ... instances features"
+ hidden,
+ self.W,
+ "... instances hidden, instances hidden features -> ... instances features",
)
return F.relu(out + self.b_final)
-
# def generate_batch(self, batch_size) -> Float[Tensor, "batch_size instances features"]:
# '''
# Generates a batch of data. We'll return to this function later when we apply correlations.
@@ -100,50 +116,104 @@ def forward(
# )
# return batch
- def generate_correlated_features(self, batch_size, n_correlated_pairs) -> Float[Tensor, "batch_size instances features"]:
- '''
+ def generate_correlated_features(
+ self, batch_size, n_correlated_pairs
+ ) -> Float[Tensor, "batch_size instances features"]: # noqa
+ """
Generates a batch of correlated features.
Each output[i, j, 2k] and output[i, j, 2k + 1] are correlated, i.e. one is present iff the other is present.
- '''
- feat = t.rand((batch_size, self.cfg.n_instances, 2 * n_correlated_pairs), device=self.W.device)
- feat_set_seeds = t.rand((batch_size, self.cfg.n_instances, n_correlated_pairs), device=self.W.device)
+ """
+ feat = t.rand(
+ (batch_size, self.cfg.n_instances, 2 * n_correlated_pairs),
+ device=self.W.device,
+ )
+ feat_set_seeds = t.rand(
+ (batch_size, self.cfg.n_instances, n_correlated_pairs), device=self.W.device
+ )
feat_set_is_present = feat_set_seeds <= self.feature_probability[:, [0]]
- feat_is_present = einops.repeat(feat_set_is_present, "batch instances features -> batch instances (features pair)", pair=2)
+ feat_is_present = einops.repeat(
+ feat_set_is_present,
+ "batch instances features -> batch instances (features pair)",
+ pair=2,
+ )
return t.where(feat_is_present, feat, 0.0)
- def generate_anticorrelated_features(self, batch_size, n_anticorrelated_pairs) -> Float[Tensor, "batch_size instances features"]:
- '''
+ def generate_anticorrelated_features(
+ self, batch_size, n_anticorrelated_pairs
+ ) -> Float[Tensor, "batch_size instances features"]: # noqa
+ """
Generates a batch of anti-correlated features.
Each output[i, j, 2k] and output[i, j, 2k + 1] are anti-correlated, i.e. one is present iff the other is absent.
- '''
- feat = t.rand((batch_size, self.cfg.n_instances, 2 * n_anticorrelated_pairs), device=self.W.device)
- feat_set_seeds = t.rand((batch_size, self.cfg.n_instances, n_anticorrelated_pairs), device=self.W.device)
- first_feat_seeds = t.rand((batch_size, self.cfg.n_instances, n_anticorrelated_pairs), device=self.W.device)
+ """
+ feat = t.rand(
+ (batch_size, self.cfg.n_instances, 2 * n_anticorrelated_pairs),
+ device=self.W.device,
+ )
+ feat_set_seeds = t.rand(
+ (batch_size, self.cfg.n_instances, n_anticorrelated_pairs),
+ device=self.W.device,
+ )
+ first_feat_seeds = t.rand(
+ (batch_size, self.cfg.n_instances, n_anticorrelated_pairs),
+ device=self.W.device,
+ )
feat_set_is_present = feat_set_seeds <= 2 * self.feature_probability[:, [0]]
first_feat_is_present = first_feat_seeds <= 0.5
- first_feats = t.where(feat_set_is_present & first_feat_is_present, feat[:, :, :n_anticorrelated_pairs], 0.0)
- second_feats = t.where(feat_set_is_present & (~first_feat_is_present), feat[:, :, n_anticorrelated_pairs:], 0.0)
- return einops.rearrange(t.concat([first_feats, second_feats], dim=-1), "batch instances (pair features) -> batch instances (features pair)", pair=2)
-
- def generate_uncorrelated_features(self, batch_size, n_uncorrelated) -> Float[Tensor, "batch_size instances features"]:
- '''
+ first_feats = t.where(
+ feat_set_is_present & first_feat_is_present,
+ feat[:, :, :n_anticorrelated_pairs],
+ 0.0,
+ )
+ second_feats = t.where(
+ feat_set_is_present & (~first_feat_is_present),
+ feat[:, :, n_anticorrelated_pairs:],
+ 0.0,
+ )
+ return einops.rearrange(
+ t.concat([first_feats, second_feats], dim=-1),
+ "batch instances (pair features) -> batch instances (features pair)",
+ pair=2,
+ )
+
+ def generate_uncorrelated_features(
+ self, batch_size, n_uncorrelated
+ ) -> Float[Tensor, "batch_size instances features"]: # noqa
+ """
Generates a batch of uncorrelated features.
- '''
- feat = t.rand((batch_size, self.cfg.n_instances, n_uncorrelated), device=self.W.device)
- feat_seeds = t.rand((batch_size, self.cfg.n_instances, n_uncorrelated), device=self.W.device)
+ """
+ feat = t.rand(
+ (batch_size, self.cfg.n_instances, n_uncorrelated), device=self.W.device
+ )
+ feat_seeds = t.rand(
+ (batch_size, self.cfg.n_instances, n_uncorrelated), device=self.W.device
+ )
feat_is_present = feat_seeds <= self.feature_probability[:, [0]]
return t.where(feat_is_present, feat, 0.0)
-
- def generate_batch(self, batch_size) -> Float[Tensor, "batch_size instances features"]:
- '''
+
+ def generate_batch(
+ self, batch_size
+ ) -> Float[Tensor, "batch_size instances features"]: # noqa
+ """
Generates a batch of data, with optional correlated & anticorrelated features.
- '''
- n_uncorrelated = self.cfg.n_features - 2 * self.cfg.n_correlated_pairs - 2 * self.cfg.n_anticorrelated_pairs
+ """
+ n_uncorrelated = (
+ self.cfg.n_features
+ - 2 * self.cfg.n_correlated_pairs
+ - 2 * self.cfg.n_anticorrelated_pairs
+ )
data = []
if self.cfg.n_correlated_pairs > 0:
- data.append(self.generate_correlated_features(batch_size, self.cfg.n_correlated_pairs))
+ data.append(
+ self.generate_correlated_features(
+ batch_size, self.cfg.n_correlated_pairs
+ )
+ )
if self.cfg.n_anticorrelated_pairs > 0:
- data.append(self.generate_anticorrelated_features(batch_size, self.cfg.n_anticorrelated_pairs))
+ data.append(
+ self.generate_anticorrelated_features(
+ batch_size, self.cfg.n_anticorrelated_pairs
+ )
+ )
if n_uncorrelated > 0:
data.append(self.generate_uncorrelated_features(batch_size, n_uncorrelated))
batch = t.cat(data, dim=-1)
@@ -151,21 +221,22 @@ def generate_batch(self, batch_size) -> Float[Tensor, "batch_size instances feat
def calculate_loss(
self,
- out: Float[Tensor, "batch instances features"],
- batch: Float[Tensor, "batch instances features"],
- ) -> Float[Tensor, ""]:
- '''
+ out: Float[Tensor, "batch instances features"], # noqa
+ batch: Float[Tensor, "batch instances features"], # noqa
+ ) -> Float[Tensor, ""]: # noqa
+ """
Calculates the loss for a given batch, using this loss described in the Toy Models paper:
https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss
Note, `model.importance` is guaranteed to broadcast with the shape of `out` and `batch`.
- '''
+ """
error = self.importance * ((batch - out) ** 2)
- loss = einops.reduce(error, 'batch instances features -> instances', 'mean').sum()
+ loss = einops.reduce(
+ error, "batch instances features -> instances", "mean"
+ ).sum()
return loss
-
def optimize(
self,
batch_size: int = 1024,
@@ -174,20 +245,19 @@ def optimize(
lr: float = 1e-3,
lr_scale: Callable[[int, int], float] = constant_lr,
):
- '''
+ """
Optimizes the model using the given hyperparameters.
- '''
+ """
optimizer = t.optim.Adam(list(self.parameters()), lr=lr)
progress_bar = tqdm(range(steps), desc="Training Toy Model")
-
- for step in progress_bar:
+ for step in progress_bar:
# Update learning rate
step_lr = lr * lr_scale(step, steps)
for group in optimizer.param_groups:
- group['lr'] = step_lr
-
+ group["lr"] = step_lr
+
# Optimize
optimizer.zero_grad()
batch = self.generate_batch(batch_size)
@@ -198,48 +268,57 @@ def optimize(
# Display progress bar
if step % log_freq == 0 or (step + 1 == steps):
- progress_bar.set_postfix(loss=loss.item()/self.cfg.n_instances, lr=step_lr)
+ progress_bar.set_postfix(
+ loss=loss.item() / self.cfg.n_instances, lr=step_lr
+ )
Arr = np.ndarray
def plot_features_in_2d(
- values: Float[Tensor, "timesteps instances d_hidden feats"],
- colors = None, # shape [timesteps instances feats]
+ values: Float[Tensor, "timesteps instances d_hidden feats"], # noqa
+ colors=None, # shape [timesteps instances feats]
title: Optional[str] = None,
subplot_titles: Optional[List[str]] = None,
save: Optional[str] = None,
colab: bool = False,
):
- '''
+ """
Visualises superposition in 2D.
If values is 4D, the first dimension is assumed to be timesteps, and an animation is created.
- '''
+ """
# Convert values to 4D for consistency
if values.ndim == 3:
values = values.unsqueeze(0)
values = values.transpose(-1, -2)
-
+
# Get dimensions
n_timesteps, n_instances, n_features, _ = values.shape
# If we have a large number of features per plot (i.e. we're plotting projections of data) then use smaller lines
linewidth, markersize = (1, 4) if (n_features >= 25) else (2, 10)
-
+
# Convert colors to 3D, if it's 2D (i.e. same colors for all instances)
if isinstance(colors, list) and isinstance(colors[0], str):
colors = [colors for _ in range(n_instances)]
# Convert colors to something which has 4D, if it is 3D (i.e. same colors for all timesteps)
- if any([
- colors is None,
- isinstance(colors, list) and isinstance(colors[0], list) and isinstance(colors[0][0], str),
- (isinstance(colors, Tensor) or isinstance(colors, Arr)) and colors.ndim == 3,
- ]):
+ if any(
+ [
+ colors is None,
+ isinstance(colors, list)
+ and isinstance(colors[0], list)
+ and isinstance(colors[0][0], str),
+ (isinstance(colors, Tensor) or isinstance(colors, Arr))
+ and colors.ndim == 3,
+ ]
+ ):
colors = [colors for _ in range(values.shape[0])]
# Now that colors has length `timesteps` in some sense, we can convert it to lists of strings
- colors = [parse_colors_for_superposition_plot(c, n_instances, n_features) for c in colors]
+ colors = [
+ parse_colors_for_superposition_plot(c, n_instances, n_features) for c in colors
+ ]
# Same for subplot titles & titles
if subplot_titles is not None:
@@ -253,7 +332,7 @@ def plot_features_in_2d(
fig, axs = plt.subplots(1, n_instances, figsize=(5 * n_instances, 5))
if n_instances == 1:
axs = [axs]
-
+
# If there are titles, add more spacing for them
fig.subplots_adjust(bottom=0.2, top=0.9, left=0.1, right=0.9)
if title:
@@ -264,12 +343,20 @@ def plot_features_in_2d(
for instance_idx, ax in enumerate(axs):
ax.set_xlim(-1.5, 1.5)
ax.set_ylim(-1.5, 1.5)
- ax.set_aspect('equal', adjustable='box')
+ ax.set_aspect("equal", adjustable="box")
instance_lines = []
instance_markers = []
for feature_idx in range(n_features):
- line, = ax.plot([], [], color=colors[0][instance_idx][feature_idx], lw=linewidth)
- marker, = ax.plot([], [], color=colors[0][instance_idx][feature_idx], marker='o', markersize=markersize)
+ (line,) = ax.plot(
+ [], [], color=colors[0][instance_idx][feature_idx], lw=linewidth
+ )
+ (marker,) = ax.plot(
+ [],
+ [],
+ color=colors[0][instance_idx][feature_idx],
+ marker="o",
+ markersize=markersize,
+ )
instance_lines.append(line)
instance_markers.append(marker)
lines.append(instance_lines)
@@ -280,19 +367,26 @@ def update(val):
# It works if I use t = int(val), so long as I put something like X = slider.val first. Idk why!
if n_timesteps > 1:
_ = slider.val
- t = int(val)
+ t = int(val)
for instance_idx in range(n_instances):
for feature_idx in range(n_features):
x, y = values[t, instance_idx, feature_idx].tolist()
lines[instance_idx][feature_idx].set_data([0, x], [0, y])
markers[instance_idx][feature_idx].set_data(x, y)
- lines[instance_idx][feature_idx].set_color(colors[t][instance_idx][feature_idx])
- markers[instance_idx][feature_idx].set_color(colors[t][instance_idx][feature_idx])
+ lines[instance_idx][feature_idx].set_color(
+ colors[t][instance_idx][feature_idx]
+ )
+ markers[instance_idx][feature_idx].set_color(
+ colors[t][instance_idx][feature_idx]
+ )
if title:
fig.suptitle(title[t], fontsize=15)
if subplot_titles:
- axs[instance_idx].set_title(subplot_titles[t][instance_idx], fontsize=12)
+ axs[instance_idx].set_title(
+ subplot_titles[t][instance_idx], fontsize=12
+ )
fig.canvas.draw_idle()
+
def play(event):
_ = slider.val
for i in range(n_timesteps):
@@ -303,8 +397,10 @@ def play(event):
if n_timesteps > 1:
# Create the slider
- ax_slider = plt.axes([0.15, 0.05, 0.7, 0.05], facecolor='lightgray')
- slider = Slider(ax_slider, 'Time', 0, n_timesteps - 1, valinit=0, valfmt='%1.0f')
+ ax_slider = plt.axes([0.15, 0.05, 0.7, 0.05], facecolor="lightgray")
+ slider = Slider(
+ ax_slider, "Time", 0, n_timesteps - 1, valinit=0, valfmt="%1.0f"
+ )
# Create the play button
# ax_button = plt.axes([0.8, 0.05, 0.08, 0.05], facecolor='lightgray')
@@ -321,25 +417,28 @@ def play(event):
# Save
if isinstance(save, str):
- ani = FuncAnimation(fig, update, frames=n_timesteps, interval=0.04, repeat=False)
- ani.save(save, writer='pillow', fps=25)
+ ani = FuncAnimation(
+ fig, update, frames=n_timesteps, interval=0.04, repeat=False
+ )
+ ani.save(save, writer="pillow", fps=25)
elif colab:
- ani = FuncAnimation(fig, update, frames=n_timesteps, interval=0.04, repeat=False)
+ ani = FuncAnimation(
+ fig, update, frames=n_timesteps, interval=0.04, repeat=False
+ )
clear_output()
return ani
-
def parse_colors_for_superposition_plot(
- colors: Optional[Union[Tuple[int, int], Float[Tensor, "instances feats"]]],
+ colors: Optional[Union[Tuple[int, int], Float[Tensor, "instances feats"]]], # noqa
n_instances: int,
n_feats: int,
) -> List[List[str]]:
- '''
+ """
There are lots of different ways colors can be represented in the superposition plot.
-
+
This function unifies them all by turning colors into a list of lists of strings, i.e. one color for each instance & feature.
- '''
+ """
# If colors is a tensor, we assume it's the importances tensor, and we color according to a viridis color scheme
# if isinstance(colors, Tensor):
# colors = t.broadcast_to(colors, (n_instances, n_feats))
@@ -347,24 +446,24 @@ def parse_colors_for_superposition_plot(
# [helper_get_viridis(v.item()) for v in colors_for_this_instance]
# for colors_for_this_instance in colors
# ]
-
+
# If colors is a tuple of ints, it's interpreted as number of correlated / anticorrelated pairs
if isinstance(colors, tuple):
n_corr, n_anti = colors
n_indep = n_feats - 2 * (n_corr - n_anti)
colors = [
- ["blue", "blue", "limegreen", "limegreen", "purple", "purple"][:n_corr*2] + ["red", "red", "orange", "orange", "brown", "brown"][:n_anti*2] + ["black"] * n_indep
+ ["blue", "blue", "limegreen", "limegreen", "purple", "purple"][: n_corr * 2]
+ + ["red", "red", "orange", "orange", "brown", "brown"][: n_anti * 2]
+ + ["black"] * n_indep
for _ in range(n_instances)
]
-
+
# If colors is a string, make all datapoints that color
elif isinstance(colors, str):
colors = [[colors] * n_feats] * n_instances
-
+
# Lastly, if colors is None, make all datapoints black
elif colors is None:
colors = [["black"] * n_feats] * n_instances
-
- return colors
-
+ return colors
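
To see the batch layout produced by `generate_batch` above: correlated pairs come first, then anticorrelated pairs, then independent features. A sketch assuming `Config` and `Model` from this file (the `Config` field names are taken from their use in `Model`):

cfg = Config(
    n_instances=1,
    n_features=6,
    n_hidden=2,
    n_correlated_pairs=1,
    n_anticorrelated_pairs=1,
)
model = Model(cfg, feature_probability=0.1)
batch = model.generate_batch(256)  # shape (256, 1, 6)
# columns 0-1 fire together, exactly one of columns 2-3 fires when that
# pair is active, and columns 4-5 are sampled independently
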
diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py
index ecf3c702..5a3107a7 100644
--- a/sae_training/train_sae_on_language_model.py
+++ b/sae_training/train_sae_on_language_model.py
@@ -1,5 +1,3 @@
-
-
import torch
from torch.optim import Adam
from tqdm import tqdm
@@ -23,32 +21,35 @@ def train_sae_on_language_model(
use_wandb: bool = False,
wandb_log_frequency: int = 50,
):
-
total_training_tokens = sparse_autoencoder.cfg.total_training_tokens
total_training_steps = total_training_tokens // batch_size
n_training_steps = 0
n_training_tokens = 0
-
+
if n_checkpoints > 0:
- checkpoint_thresholds = list(range(0, total_training_tokens, total_training_tokens // n_checkpoints))[1:]
-
+ checkpoint_thresholds = list(
+ range(0, total_training_tokens, total_training_tokens // n_checkpoints)
+ )[1:]
+
# track active features
- act_freq_scores = torch.zeros(sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device)
- n_forward_passes_since_fired = torch.zeros(sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device)
+ act_freq_scores = torch.zeros(
+ sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device
+ )
+ n_forward_passes_since_fired = torch.zeros(
+ sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device
+ )
n_frac_active_tokens = 0
-
- optimizer = Adam(sparse_autoencoder.parameters(),
- lr = sparse_autoencoder.cfg.lr)
+
+ optimizer = Adam(sparse_autoencoder.parameters(), lr=sparse_autoencoder.cfg.lr)
scheduler = get_scheduler(
sparse_autoencoder.cfg.lr_scheduler_name,
optimizer=optimizer,
- warm_up_steps = sparse_autoencoder.cfg.lr_warm_up_steps,
+ warm_up_steps=sparse_autoencoder.cfg.lr_warm_up_steps,
training_steps=total_training_steps,
- lr_end=sparse_autoencoder.cfg.lr / 10, # heuristic for now.
+ lr_end=sparse_autoencoder.cfg.lr / 10, # heuristic for now.
)
sparse_autoencoder.initialize_b_dec(activation_store)
sparse_autoencoder.train()
-
pbar = tqdm(total=total_training_tokens, desc="Training SAE")
while n_training_tokens < total_training_tokens:
@@ -59,42 +60,52 @@ def train_sae_on_language_model(
# after resampling, reset the sparsity:
if (n_training_steps + 1) % feature_sampling_window == 0:
-
feature_sparsity = act_freq_scores / n_frac_active_tokens
log_feature_sparsity = torch.log10(feature_sparsity + 1e-10).detach().cpu()
if use_wandb:
wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy())
wandb.log(
- {
+ {
"metrics/mean_log10_feature_sparsity": log_feature_sparsity.mean().item(),
"plots/feature_density_line_chart": wandb_histogram,
},
step=n_training_steps,
)
-
- act_freq_scores = torch.zeros(sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device)
+
+ act_freq_scores = torch.zeros(
+ sparse_autoencoder.cfg.d_sae, device=sparse_autoencoder.cfg.device
+ )
n_frac_active_tokens = 0
            scheduler.step()
-
+
scheduler.step()
-
+
optimizer.zero_grad()
-
- ghost_grad_neuron_mask = (n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window).bool()
+
+ ghost_grad_neuron_mask = (
+ n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window
+ ).bool()
sae_in = activation_store.next_batch()
-
+
# Forward and Backward Passes
- sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss = sparse_autoencoder(
+ (
+ sae_out,
+ feature_acts,
+ loss,
+ mse_loss,
+ l1_loss,
+ ghost_grad_loss,
+ ) = sparse_autoencoder(
sae_in,
ghost_grad_neuron_mask,
)
- did_fire = ((feature_acts > 0).float().sum(-2) > 0)
+ did_fire = (feature_acts > 0).float().sum(-2) > 0
n_forward_passes_since_fired += 1
n_forward_passes_since_fired[did_fire] = 0
-
+
n_training_tokens += batch_size
with torch.no_grad():
@@ -107,16 +118,17 @@ def train_sae_on_language_model(
# metrics for currents acts
l0 = (feature_acts > 0).float().sum(-1).mean()
current_learning_rate = optimizer.param_groups[0]["lr"]
-
+
per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
- total_variance = (sae_in-sae_in.mean(0)).pow(2).sum(-1)
- explained_variance = 1 - per_token_l2_loss/total_variance
-
+ total_variance = (sae_in - sae_in.mean(0)).pow(2).sum(-1)
+ explained_variance = 1 - per_token_l2_loss / total_variance
+
wandb.log(
{
# losses
"losses/mse_loss": mse_loss.item(),
- "losses/l1_loss": l1_loss.item() / sparse_autoencoder.l1_coefficient, # normalize by l1 coefficient
+ "losses/l1_loss": l1_loss.item()
+ / sparse_autoencoder.l1_coefficient, # normalize by l1 coefficient
"losses/ghost_grad_loss": ghost_grad_loss.item(),
"losses/overall_loss": loss.item(),
# variance explained
@@ -151,7 +163,7 @@ def train_sae_on_language_model(
sparse_autoencoder.eval()
run_evals(sparse_autoencoder, activation_store, model, n_training_steps)
sparse_autoencoder.train()
-
+
pbar.set_description(
f"{n_training_steps}| MSE Loss {mse_loss.item():.3f} | L1 {l1_loss.item():.3f}"
)
@@ -161,7 +173,6 @@ def train_sae_on_language_model(
sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
optimizer.step()
-
# checkpoint if at checkpoint frequency
if n_checkpoints > 0 and n_training_tokens > checkpoint_thresholds[0]:
cfg = sparse_autoencoder.cfg
@@ -175,29 +186,33 @@ def train_sae_on_language_model(
n_checkpoints = 0
if cfg.log_to_wandb:
model_artifact = wandb.Artifact(
- f"{sparse_autoencoder.get_name()}", type="model", metadata=dict(cfg.__dict__)
+ f"{sparse_autoencoder.get_name()}",
+ type="model",
+ metadata=dict(cfg.__dict__),
)
model_artifact.add_file(path)
wandb.log_artifact(model_artifact)
-
+
sparsity_artifact = wandb.Artifact(
- f"{sparse_autoencoder.get_name()}_log_feature_sparsity", type="log_feature_sparsity", metadata=dict(cfg.__dict__)
+ f"{sparse_autoencoder.get_name()}_log_feature_sparsity",
+ type="log_feature_sparsity",
+ metadata=dict(cfg.__dict__),
)
sparsity_artifact.add_file(log_feature_sparsity_path)
wandb.log_artifact(sparsity_artifact)
-
-
+
n_training_steps += 1
-
+
log_feature_sparsity_path = f"{sparse_autoencoder.cfg.checkpoint_path}/final_{sparse_autoencoder.get_name()}_log_feature_sparsity.pt"
sparse_autoencoder.save_model(path)
torch.save(log_feature_sparsity, log_feature_sparsity_path)
if cfg.log_to_wandb:
sparsity_artifact = wandb.Artifact(
- f"{sparse_autoencoder.get_name()}_log_feature_sparsity", type="log_feature_sparsity", metadata=dict(cfg.__dict__)
- )
+ f"{sparse_autoencoder.get_name()}_log_feature_sparsity",
+ type="log_feature_sparsity",
+ metadata=dict(cfg.__dict__),
+ )
sparsity_artifact.add_file(log_feature_sparsity_path)
wandb.log_artifact(sparsity_artifact)
-
return sparse_autoencoder
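
The ghost-grad gating above hinges on `n_forward_passes_since_fired`; isolated, the bookkeeping looks like this (toy numbers):

import torch

dead_feature_window = 2
n_forward_passes_since_fired = torch.zeros(4)
for feature_acts in [
    torch.tensor([[1.0, 0.0, 0.0, 2.0]]),
    torch.tensor([[0.0, 0.0, 0.0, 1.0]]),
    torch.tensor([[2.0, 0.0, 0.0, 0.0]]),
]:
    did_fire = (feature_acts > 0).float().sum(-2) > 0
    n_forward_passes_since_fired += 1
    n_forward_passes_since_fired[did_fire] = 0
ghost_grad_neuron_mask = (n_forward_passes_since_fired > dead_feature_window).bool()
# -> tensor([False,  True,  True, False]): features 1 and 2 have not fired
#    for more than dead_feature_window passes and receive ghost grads
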
diff --git a/sae_training/train_sae_on_toy_model.py b/sae_training/train_sae_on_toy_model.py
index 7ad3c29e..522b151d 100644
--- a/sae_training/train_sae_on_toy_model.py
+++ b/sae_training/train_sae_on_toy_model.py
@@ -32,7 +32,6 @@ def train_toy_sae(
pbar = tqdm(dataloader, desc="Training SAE")
for _, batch in enumerate(pbar):
-
batch = next(dataloader)
# Make sure the W_dec is still zero-norm
sparse_autoencoder.set_decoder_norm_to_unit_norm()
@@ -43,12 +42,10 @@ def train_toy_sae(
loss.backward()
sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
optimizer.step()
-
+
n_training_tokens += batch_size
with torch.no_grad():
-
-
# Calculate the sparsities, and add it to a list
act_freq_scores = (feature_acts.abs() > 0).float().sum(0)
frac_active_list.append(act_freq_scores)
@@ -67,15 +64,13 @@ def train_toy_sae(
len(frac_active_list) * batch_size
)
-
l0 = (feature_acts > 0).float().sum(-1).mean()
current_learning_rate = optimizer.param_groups[0]["lr"]
l2_norm = torch.norm(feature_acts, dim=1).mean()
-
+
l2_norm_in = torch.norm(batch, dim=-1)
l2_norm_out = torch.norm(sae_out, dim=-1)
- l2_norm_ratio = l2_norm_out / (1e-6+l2_norm_in)
-
+ l2_norm_ratio = l2_norm_out / (1e-6 + l2_norm_in)
if use_wandb and ((n_training_steps + 1) % wandb_log_frequency == 0):
wandb.log(
@@ -104,7 +99,7 @@ def train_toy_sae(
},
step=n_training_steps,
)
-
+
if (n_training_steps + 1) % (wandb_log_frequency * 100) == 0:
log_feature_sparsity = torch.log10(feature_sparsity + 1e-8)
wandb.log(
@@ -121,7 +116,6 @@ def train_toy_sae(
)
pbar.update(batch_size)
-
# If we did checkpointing we'd do it here.
n_training_steps += 1
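
The `feature_sparsity` statistic logged above is firing counts normalized by tokens seen; a sketch, assuming the window's per-batch counts are summed as in the loop above:

import torch

batch_size = 4
frac_active_list = [torch.tensor([4.0, 0.0, 2.0]), torch.tensor([0.0, 0.0, 4.0])]
feature_sparsity = torch.stack(frac_active_list).sum(0) / (
    len(frac_active_list) * batch_size
)
# -> tensor([0.5000, 0.0000, 0.7500]): fraction of tokens each feature fired on
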
diff --git a/sae_training/utils.py b/sae_training/utils.py
index e3af6bc1..c2633a56 100644
--- a/sae_training/utils.py
+++ b/sae_training/utils.py
@@ -8,7 +8,7 @@
from sae_training.sparse_autoencoder import SparseAutoencoder
-class LMSparseAutoencoderSessionloader():
+class LMSparseAutoencoderSessionloader:
"""
Responsible for loading all required
artifacts and files for training
@@ -18,25 +18,28 @@ class LMSparseAutoencoderSessionloader():
def __init__(self, cfg: LanguageModelSAERunnerConfig):
self.cfg = cfg
-
-
- def load_session(self) -> Tuple[HookedTransformer, SparseAutoencoder, ActivationsStore]:
- '''
+
+ def load_session(
+ self,
+ ) -> Tuple[HookedTransformer, SparseAutoencoder, ActivationsStore]:
+ """
Loads a session for training a sparse autoencoder on a language model.
- '''
-
+ """
+
model = self.get_model(self.cfg.model_name)
model.to(self.cfg.device)
activations_loader = self.get_activations_loader(self.cfg, model)
sparse_autoencoder = self.initialize_sparse_autoencoder(self.cfg)
-
+
return model, sparse_autoencoder, activations_loader
-
+
@classmethod
- def load_session_from_pretrained(cls, path: str) -> Tuple[HookedTransformer, SparseAutoencoder, ActivationsStore]:
- '''
+ def load_session_from_pretrained(
+ cls, path: str
+ ) -> Tuple[HookedTransformer, SparseAutoencoder, ActivationsStore]:
+ """
Loads a session for analysing a pretrained sparse autoencoder.
- '''
+ """
if torch.backends.mps.is_available():
cfg = torch.load(path, map_location="mps")["cfg"]
cfg.device = "mps"
@@ -47,61 +50,68 @@ def load_session_from_pretrained(cls, path: str) -> Tuple[HookedTransformer, Spa
model, _, activations_loader = cls(cfg).load_session()
sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)
-
+
return model, sparse_autoencoder, activations_loader
-
+
def get_model(self, model_name: str):
- '''
+ """
Loads a model from transformer lens
- '''
-
+ """
+
# Todo: add check that model_name is valid
-
+
model = HookedTransformer.from_pretrained(model_name)
-
- return model
-
+
+ return model
+
def initialize_sparse_autoencoder(self, cfg: LanguageModelSAERunnerConfig):
- '''
+ """
Initializes a sparse autoencoder
- '''
-
+ """
+
sparse_autoencoder = SparseAutoencoder(cfg)
-
+
return sparse_autoencoder
-
- def get_activations_loader(self, cfg: LanguageModelSAERunnerConfig, model: HookedTransformer):
- '''
+
+ def get_activations_loader(
+ self, cfg: LanguageModelSAERunnerConfig, model: HookedTransformer
+ ):
+ """
Loads a DataLoaderBuffer for the activations of a language model.
- '''
-
+ """
+
activations_loader = ActivationsStore(
- cfg, model,
+ cfg,
+ model,
)
-
+
return activations_loader
+
def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[int, int]):
"""
Shuffles two buffers on disk.
"""
- assert buffer_idx_range[0] < buffer_idx_range[1], \
- "buffer_idx_range[0] must be smaller than buffer_idx_range[1]"
-
+ assert (
+ buffer_idx_range[0] < buffer_idx_range[1]
+ ), "buffer_idx_range[0] must be smaller than buffer_idx_range[1]"
+
buffer_idx1 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()
buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()
- while buffer_idx1 == buffer_idx2: # Make sure they're not the same
- buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()
-
+ while buffer_idx1 == buffer_idx2: # Make sure they're not the same
+ buffer_idx2 = torch.randint(
+ buffer_idx_range[0], buffer_idx_range[1], (1,)
+ ).item()
+
buffer1 = torch.load(f"{datapath}/{buffer_idx1}.pt")
buffer2 = torch.load(f"{datapath}/{buffer_idx2}.pt")
joint_buffer = torch.cat([buffer1, buffer2])
-
+
# Shuffle them
joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])]
- shuffled_buffer1 = joint_buffer[:buffer1.shape[0]]
- shuffled_buffer2 = joint_buffer[buffer1.shape[0]:]
-
+ shuffled_buffer1 = joint_buffer[: buffer1.shape[0]]
+ shuffled_buffer2 = joint_buffer[buffer1.shape[0] :]
+
# Save them back
torch.save(shuffled_buffer1, f"{datapath}/{buffer_idx1}.pt")
torch.save(shuffled_buffer2, f"{datapath}/{buffer_idx2}.pt")
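
Usage sketch for the shuffle above (the datapath layout — `{datapath}/{i}.pt` buffers of shape `(rows, d_in)` — is hypothetical):

shuffle_activations_pairwise(
    "activations/gpt2_blocks.2.hook_resid_pre",  # hypothetical path
    buffer_idx_range=(0, 10),
)
# picks two distinct buffer indices in [0, 10), concatenates the buffers,
# permutes the rows, and writes the two halves back to disk
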
diff --git a/scripts/generate_dashboards.py b/scripts/generate_dashboards.py
index ae8a4a7d..4e8e450e 100644
--- a/scripts/generate_dashboards.py
+++ b/scripts/generate_dashboards.py
@@ -22,35 +22,28 @@
from sae_training.utils import LMSparseAutoencoderSessionloader
-class DashboardRunner():
-
+class DashboardRunner:
def __init__(
self,
sae_path: str = None,
dashboard_parent_folder: str = "./feature_dashboards",
wandb_artifact_path: str = None,
init_session: bool = True,
-
# token pars
n_batches_to_sample_from: int = 2**12,
- n_prompts_to_select: int = 4096*6,
-
+ n_prompts_to_select: int = 4096 * 6,
# sampling pars
n_features_at_a_time: int = 1024,
max_batch_size: int = 256,
buffer_tokens: int = 8,
-
# util pars
use_wandb: bool = False,
continue_existing_dashboard: bool = True,
final_index: int = None,
):
- '''
-
- '''
-
- if wandb_artifact_path is not None:
+ """ """
+ if wandb_artifact_path is not None:
artifact_dir = f"artifacts/{wandb_artifact_path.split('/')[2]}"
if not os.path.exists(artifact_dir):
print("Downloading artifact")
@@ -59,92 +52,106 @@ def __init__(
artifact_dir = artifact.download()
path_to_artifact = f"{artifact_dir}/{os.listdir(artifact_dir)[0]}"
# feature sparsity
- feature_sparsity_path = self.get_feature_sparsity_path(wandb_artifact_path)
+ feature_sparsity_path = self.get_feature_sparsity_path(
+ wandb_artifact_path
+ )
artifact = run.use_artifact(feature_sparsity_path)
artifact_dir = artifact.download()
# add it as a property
- self.feature_sparsity = torch.load(f"{artifact_dir}/{os.listdir(artifact_dir)[0]}")
+ self.feature_sparsity = torch.load(
+ f"{artifact_dir}/{os.listdir(artifact_dir)[0]}"
+ )
else:
print("Artifact already downloaded")
path_to_artifact = f"{artifact_dir}/{os.listdir(artifact_dir)[0]}"
-
- feature_sparsity_path = self.get_feature_sparsity_path(wandb_artifact_path)
+
+ feature_sparsity_path = self.get_feature_sparsity_path(
+ wandb_artifact_path
+ )
artifact_dir = f"artifacts/{feature_sparsity_path.split('/')[2]}"
feature_sparsity_file = os.listdir(artifact_dir)[0]
- self.feature_sparsity = torch.load(f"{artifact_dir}/{feature_sparsity_file}")
-
+ self.feature_sparsity = torch.load(
+ f"{artifact_dir}/{feature_sparsity_file}"
+ )
+
self.sae_path = path_to_artifact
- else:
+ else:
assert sae_path is not None
self.sae_path = sae_path
-
+
if init_session:
self.init_sae_session()
-
+
self.n_features_at_a_time = n_features_at_a_time
self.max_batch_size = max_batch_size
self.buffer_tokens = buffer_tokens
self.use_wandb = use_wandb
- self.final_index = final_index if final_index is not None else self.sparse_autoencoder.cfg.d_sae
+ self.final_index = (
+ final_index
+ if final_index is not None
+ else self.sparse_autoencoder.cfg.d_sae
+ )
self.n_batches_to_sample_from = n_batches_to_sample_from
self.n_prompts_to_select = n_prompts_to_select
-
-
+
# Deal with file structure
if not os.path.exists(dashboard_parent_folder):
os.makedirs(dashboard_parent_folder)
- self.dashboard_folder = f"{dashboard_parent_folder}/{self.get_dashboard_folder_name()}"
+ self.dashboard_folder = (
+ f"{dashboard_parent_folder}/{self.get_dashboard_folder_name()}"
+ )
if not os.path.exists(self.dashboard_folder):
os.makedirs(self.dashboard_folder)
-
+
if not continue_existing_dashboard:
# check if there are files there and if so abort
if len(os.listdir(self.dashboard_folder)) > 0:
raise ValueError("Dashboard folder not empty. Aborting.")
def get_feature_sparsity_path(self, wandb_artifact_path):
- prefix = wandb_artifact_path.split(':')[0]
+ prefix = wandb_artifact_path.split(":")[0]
return f"{prefix}_log_feature_sparsity:v9"
-
+
def get_dashboard_folder_name(self):
-
model = self.sparse_autoencoder.cfg.model_name
hook_point = self.sparse_autoencoder.cfg.hook_point
d_sae = self.sparse_autoencoder.cfg.d_sae
dashboard_folder_name = f"{model}_{hook_point}_{d_sae}"
-
+
return dashboard_folder_name
-
+
def init_sae_session(self):
-
- self.model, self.sparse_autoencoder, self.activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
- self.sae_path
- )
-
- def get_tokens(self, n_batches_to_sample_from = 2**12, n_prompts_to_select = 4096*6):
- '''
+ (
+ self.model,
+ self.sparse_autoencoder,
+ self.activation_store,
+ ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
+
+ def get_tokens(
+ self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6
+ ):
+ """
Get the tokens needed for dashboard generation.
- '''
-
+ """
+
all_tokens_list = []
pbar = tqdm(range(n_batches_to_sample_from))
for _ in pbar:
-
batch_tokens = self.activation_store.get_batch_tokens()
- batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][:batch_tokens.shape[0]]
+ batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][
+ : batch_tokens.shape[0]
+ ]
all_tokens_list.append(batch_tokens)
-
+
all_tokens = torch.cat(all_tokens_list, dim=0)
all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])]
return all_tokens[:n_prompts_to_select]
def get_index_to_resume_from(self):
-
for i in range(self.n_features):
-
if not os.path.exists(f"{self.dashboard_folder}/data_{i:04}.html"):
- break
+ break
n_features = self.sparse_autoencoder.cfg.d_sae
n_features_at_a_time = self.n_features_at_a_time
@@ -152,91 +159,119 @@ def get_index_to_resume_from(self):
n_features_remaining = self.final_index - id_of_last_feature_without_dashboard
n_batches_to_do = n_features_remaining // n_features_at_a_time
if self.final_index == n_features:
- id_to_start_from = max(0, n_features - (n_batches_to_do + 1) * n_features_at_a_time)
+ id_to_start_from = max(
+ 0, n_features - (n_batches_to_do + 1) * n_features_at_a_time
+ )
else:
- id_to_start_from = 0 # testing purposes only
-
-
+ id_to_start_from = 0 # testing purposes only
+
print(f"File {i} does not exist")
print(f"features left to do: {n_features_remaining}")
print(f"id_to_start_from: {id_to_start_from}")
- print(f"number of batches to do: {(n_features - id_to_start_from) // n_features_at_a_time}")
-
+ print(
+ f"number of batches to do: {(n_features - id_to_start_from) // n_features_at_a_time}"
+ )
+
return id_to_start_from
-
+
@torch.no_grad()
def get_feature_property_df(self):
-
- sparse_autoencoder= self.sparse_autoencoder
+ sparse_autoencoder = self.sparse_autoencoder
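+        # Collect per-feature diagnostics (log sparsity, encoder/decoder
+        # alignment, encoder bias) into a dataframe for the wandb summary plots.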
feature_sparsity = self.feature_sparsity
-
- W_dec_normalized = sparse_autoencoder.W_dec.cpu()# / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True)
- W_enc_normalized = sparse_autoencoder.W_enc.cpu() / sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True)
+
+ W_dec_normalized = (
+ sparse_autoencoder.W_dec.cpu()
+ ) # / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True)
+ W_enc_normalized = (
+ sparse_autoencoder.W_enc.cpu()
+ / sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True)
+ )
d_e_projection = cosine_similarity(W_dec_normalized, W_enc_normalized.T)
b_dec_projection = sparse_autoencoder.b_dec.cpu() @ W_dec_normalized.T
- temp_df = pd.DataFrame({
- "log_feature_sparsity": feature_sparsity + 1e-10,
- "d_e_projection": d_e_projection,
- # "d_e_projection_normalized": d_e_projection_normalized,
- "b_enc": sparse_autoencoder.b_enc.detach().cpu(),
- "feature": [f"feature_{i}" for i in range(sparse_autoencoder.cfg.d_sae)],
- "index": torch.arange(sparse_autoencoder.cfg.d_sae),
- "dead_neuron": (feature_sparsity < -9).cpu(),
- })
-
+ temp_df = pd.DataFrame(
+ {
+ "log_feature_sparsity": feature_sparsity + 1e-10,
+ "d_e_projection": d_e_projection,
+ # "d_e_projection_normalized": d_e_projection_normalized,
+ "b_enc": sparse_autoencoder.b_enc.detach().cpu(),
+ "feature": [
+ f"feature_{i}" for i in range(sparse_autoencoder.cfg.d_sae)
+ ],
+ "index": torch.arange(sparse_autoencoder.cfg.d_sae),
+ "dead_neuron": (feature_sparsity < -9).cpu(),
+ }
+ )
+
return temp_df
-
-
+
def run(self):
- '''
+ """
Generate the dashboard.
- '''
-
+ """
+
if self.use_wandb:
# get name from wandb
- random_suffix= str(uuid.uuid4())[:8]
+ random_suffix = str(uuid.uuid4())[:8]
name = f"{self.get_dashboard_folder_name()}_{random_suffix}"
run = wandb.init(
project="feature_dashboards",
config=self.sparse_autoencoder.cfg,
- name = name,
- tags = [
+ name=name,
+ tags=[
f"model_{self.sparse_autoencoder.cfg.model_name}",
f"hook_point_{self.sparse_autoencoder.cfg.hook_point}",
- ]
+ ],
)
-
+
if self.model is None:
self.init_sae_session()
-
# generate all the plots
if self.use_wandb:
feature_property_df = self.get_feature_property_df()
-
- fig = px.histogram(runner.feature_sparsity+1e-10, nbins=100, log_x=False, title="Feature sparsity")
- wandb.log({"plots/feature_density_histogram": wandb.Html(plotly.io.to_html(fig))})
- fig = px.histogram(self.sparse_autoencoder.b_enc.detach().cpu(), title = "b_enc", nbins = 100)
+ fig = px.histogram(
+            self.feature_sparsity + 1e-10,
+ nbins=100,
+ log_x=False,
+ title="Feature sparsity",
+ )
+ wandb.log(
+ {"plots/feature_density_histogram": wandb.Html(plotly.io.to_html(fig))}
+ )
+
+ fig = px.histogram(
+ self.sparse_autoencoder.b_enc.detach().cpu(), title="b_enc", nbins=100
+ )
wandb.log({"plots/b_enc_histogram": wandb.Html(plotly.io.to_html(fig))})
-
- fig = px.histogram(feature_property_df.d_e_projection, nbins = 100, title = "D/E projection")
- wandb.log({"plots/d_e_projection_histogram": wandb.Html(plotly.io.to_html(fig))})
-
- fig = px.histogram(self.sparse_autoencoder.b_dec.detach().cpu(), nbins=100, title = "b_dec projection onto W_dec")
- wandb.log({"plots/b_dec_projection_histogram": wandb.Html(plotly.io.to_html(fig))})
-
- fig = px.scatter_matrix(feature_property_df,
- dimensions = ["log_feature_sparsity", "d_e_projection", "b_enc"],
+
+ fig = px.histogram(
+ feature_property_df.d_e_projection, nbins=100, title="D/E projection"
+ )
+ wandb.log(
+ {"plots/d_e_projection_histogram": wandb.Html(plotly.io.to_html(fig))}
+ )
+
+ fig = px.histogram(
+ self.sparse_autoencoder.b_dec.detach().cpu(),
+ nbins=100,
+ title="b_dec projection onto W_dec",
+ )
+ wandb.log(
+ {"plots/b_dec_projection_histogram": wandb.Html(plotly.io.to_html(fig))}
+ )
+
+ fig = px.scatter_matrix(
+ feature_property_df,
+ dimensions=["log_feature_sparsity", "d_e_projection", "b_enc"],
color="dead_neuron",
hover_name="feature",
opacity=0.2,
height=800,
- width =1400,
+ width=1400,
)
wandb.log({"plots/scatter_matrix": wandb.Html(plotly.io.to_html(fig))})
-
self.n_features = self.sparse_autoencoder.cfg.d_sae
id_to_start_from = self.get_index_to_resume_from()
@@ -246,19 +281,21 @@ def run(self):
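+        # Split the remaining feature indices into batches of
+        # n_features_at_a_time for incremental dashboard generation.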
feature_idx = torch.tensor(range(id_to_start_from, id_to_end_at))
feature_idx = feature_idx.reshape(-1, self.n_features_at_a_time)
feature_idx = [x.tolist() for x in feature_idx]
-
+
print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}")
print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}")
print(f"Writing files to: {self.dashboard_folder}")
# get tokens:
start = time.time()
- tokens = self.get_tokens(self.n_batches_to_sample_from, self.n_prompts_to_select)
+ tokens = self.get_tokens(
+ self.n_batches_to_sample_from, self.n_prompts_to_select
+ )
end = time.time()
print(f"Time to get tokens: {end - start}")
if self.use_wandb:
wandb.log({"time/time_to_get_tokens": end - start})
-
+
with torch.no_grad():
for interesting_features in tqdm(feature_idx):
print(interesting_features)
@@ -272,63 +309,71 @@ def run(self):
tokens=tokens,
feature_idx=interesting_features,
max_batch_size=self.max_batch_size,
- left_hand_k = 3,
- buffer = (self.buffer_tokens, self.buffer_tokens),
- n_groups = 10,
- first_group_size = 20,
- other_groups_size = 5,
- verbose = True,
+ left_hand_k=3,
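+                    # buffer: presumed (left, right) token context kept around
+                    # each max-activating token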
+ buffer=(self.buffer_tokens, self.buffer_tokens),
+ n_groups=10,
+ first_group_size=20,
+ other_groups_size=5,
+ verbose=True,
)
-
+
for i, test_idx in enumerate(feature_data.keys()):
html_str = feature_data[test_idx].get_all_html()
- with open(f"{self.dashboard_folder}/data_{test_idx:04}.html", "w") as f:
+ with open(
+ f"{self.dashboard_folder}/data_{test_idx:04}.html", "w"
+ ) as f:
f.write(html_str)
-
+
if i < 10 and self.use_wandb:
# upload the html as an artifact
artifact = wandb.Artifact(f"feature_{test_idx}", type="feature")
- artifact.add_file(f"{self.dashboard_folder}/data_{test_idx:04}.html")
+ artifact.add_file(
+ f"{self.dashboard_folder}/data_{test_idx:04}.html"
+ )
run.log_artifact(artifact)
-
+
# also upload as html to dashboard
wandb.log(
- {f"features/feature_dashboard": wandb.Html(f"{self.dashboard_folder}/data_{test_idx:04}.html")},
- step = test_idx
- )
-
+ {
+ f"features/feature_dashboard": wandb.Html(
+ f"{self.dashboard_folder}/data_{test_idx:04}.html"
+ )
+ },
+ step=test_idx,
+ )
+
# when done zip the folder
- shutil.make_archive(self.dashboard_folder, 'zip', self.dashboard_folder)
-
+ shutil.make_archive(self.dashboard_folder, "zip", self.dashboard_folder)
+
# then upload the zip as an artifact
artifact = wandb.Artifact("dashboard", type="zipped_feature_dashboards")
artifact.add_file(f"{self.dashboard_folder}.zip")
run.log_artifact(artifact)
-
+
# terminate the run
run.finish()
-
+
# delete the dashboard folder
shutil.rmtree(self.dashboard_folder)
-
- return
+
+ return
-# test it
+# test it
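+# Example invocation: with sae_path=None, the runner downloads the SAE from
+# the wandb artifact path below (assumes that artifact is accessible).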
runner = DashboardRunner(
- sae_path = None,
- dashboard_parent_folder = "../feature_dashboards",
- wandb_artifact_path = "jbloom/mats_sae_training_gpt2_small_resid_pre_5/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v19",
- init_session = True,
- n_batches_to_sample_from = 2**12,
- n_prompts_to_select = 4096*6,
- n_features_at_a_time = 128,
- max_batch_size = 256,
- buffer_tokens = 8,
- use_wandb = True,
- continue_existing_dashboard = True,
+ sae_path=None,
+ dashboard_parent_folder="../feature_dashboards",
+ wandb_artifact_path="jbloom/mats_sae_training_gpt2_small_resid_pre_5/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v19",
+ init_session=True,
+ n_batches_to_sample_from=2**12,
+ n_prompts_to_select=4096 * 6,
+ n_features_at_a_time=128,
+ max_batch_size=256,
+ buffer_tokens=8,
+ use_wandb=True,
+ continue_existing_dashboard=True,
)
-runner.run()
\ No newline at end of file
+runner.run()
diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py
index 28f064ee..ade45557 100644
--- a/tests/benchmark/test_language_model_sae_runner.py
+++ b/tests/benchmark/test_language_model_sae_runner.py
@@ -5,62 +5,51 @@
def test_language_model_sae_runner():
-
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
-
- cfg = LanguageModelSAERunnerConfig(
+ cfg = LanguageModelSAERunnerConfig(
         # Data Generating Function (Model + Training Distribution)
- model_name = "gelu-2l",
- hook_point = "blocks.0.hook_mlp_out",
- hook_point_layer = 0,
- d_in = 512,
- dataset_path = "NeelNanda/c4-tokenized-2b",
+ model_name="gelu-2l",
+ hook_point="blocks.0.hook_mlp_out",
+ hook_point_layer=0,
+ d_in=512,
+ dataset_path="NeelNanda/c4-tokenized-2b",
is_dataset_tokenized=True,
-
# SAE Parameters
- expansion_factor = 16,
- b_dec_init_method="mean", # not ideal but quicker when testing code.
-
+ expansion_factor=16,
+ b_dec_init_method="mean", # not ideal but quicker when testing code.
# Training Parameters
- lr = 1e-4,
- l1_coefficient = 3e-4,
- train_batch_size = 4096,
- context_size = 128,
-
+ lr=1e-4,
+ l1_coefficient=3e-4,
+ train_batch_size=4096,
+ context_size=128,
# Activation Store Parameters
- n_batches_in_buffer = 24,
- total_training_tokens = 1_000_000 * 10,
- store_batch_size = 32,
-
+ n_batches_in_buffer=24,
+ total_training_tokens=1_000_000 * 10,
+ store_batch_size=32,
# Resampling protocol
use_ghost_grads=True,
-
- feature_sampling_window = 3000, # in steps
+ feature_sampling_window=3000, # in steps
dead_feature_window=5000,
- dead_feature_threshold = 1e-8,
-
+ dead_feature_threshold=1e-8,
# WANDB
- log_to_wandb = True,
- wandb_project= "mats_sae_training_benchmarks",
- wandb_entity = None,
-
+ log_to_wandb=True,
+ wandb_project="mats_sae_training_benchmarks",
+ wandb_entity=None,
# Misc
- device = device,
- seed = 42,
- n_checkpoints = 5,
- checkpoint_path = "checkpoints",
- dtype = torch.float32,
- )
+ device=device,
+ seed=42,
+ n_checkpoints=5,
+ checkpoint_path="checkpoints",
+ dtype=torch.float32,
+ )
sparse_autoencoder = language_model_sae_runner(cfg)
assert sparse_autoencoder is not None
     # we'll know whether or not this works by looking at the dashboard!
-
-
diff --git a/tests/benchmark/test_toy_model_sae_runner.py b/tests/benchmark/test_toy_model_sae_runner.py
index 8bf762bf..811ac9c4 100644
--- a/tests/benchmark/test_toy_model_sae_runner.py
+++ b/tests/benchmark/test_toy_model_sae_runner.py
@@ -6,16 +6,14 @@
# @pytest.mark.skip(reason="I (joseph) broke this at some point, on my to do list to fix.")
def test_toy_model_sae_runner():
-
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
-
+
cfg = SAEToyModelRunnerConfig(
-
# Model Details
n_features=100,
n_hidden=10,
@@ -23,23 +21,20 @@ def test_toy_model_sae_runner():
n_anticorrelated_pairs=0,
feature_probability=0.025,
model_training_steps=10_000,
-
# SAE Parameters
d_sae=10,
- lr = 3e-4,
+ lr=3e-4,
l1_coefficient=0.001,
use_ghost_grads=False,
b_dec_init_method="mean",
-
# SAE Train Config
train_batch_size=1028,
feature_sampling_window=3_000,
dead_feature_window=1_000,
- total_training_tokens=4096*1000,
-
+ total_training_tokens=4096 * 1000,
# Other parameters
log_to_wandb=True,
- wandb_project= "mats_sae_training_benchmarks_toy",
+ wandb_project="mats_sae_training_benchmarks_toy",
wandb_log_frequency=5,
device=device,
)
diff --git a/tests/unit/test_activations_store.py b/tests/unit/test_activations_store.py
index 9db4f593..f6a6b68c 100644
--- a/tests/unit/test_activations_store.py
+++ b/tests/unit/test_activations_store.py
@@ -33,16 +33,16 @@ def cfg():
mock_config.context_size = 16
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = None
-
+
mock_config.feature_sampling_method = None
mock_config.feature_sampling_window = 50
mock_config.feature_reinit_scale = 0.1
mock_config.dead_feature_threshold = 1e-7
-
+
mock_config.n_batches_in_buffer = 4
mock_config.total_training_tokens = 1_000_000
mock_config.store_batch_size = 32
-
+
mock_config.log_to_wandb = False
mock_config.wandb_project = "test_project"
mock_config.wandb_entity = "test_entity"
@@ -50,12 +50,11 @@ def cfg():
mock_config.device = torch.device("cpu")
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
- mock_config.dtype = torch.float32
+ mock_config.dtype = torch.float32
return mock_config
-
@pytest.fixture
def cfg_head_hook():
"""
@@ -78,17 +77,16 @@ def cfg_head_hook():
mock_config.context_size = 128
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = 0
-
-
+
mock_config.feature_sampling_method = None
mock_config.feature_sampling_window = 50
mock_config.feature_reinit_scale = 0.1
mock_config.dead_feature_threshold = 1e-7
-
+
mock_config.n_batches_in_buffer = 4
mock_config.total_training_tokens = 1_000_000
mock_config.store_batch_size = 32
-
+
mock_config.log_to_wandb = False
mock_config.wandb_project = "test_project"
mock_config.wandb_entity = "test_entity"
@@ -96,7 +94,7 @@ def cfg_head_hook():
mock_config.device = torch.device("cpu")
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
- mock_config.dtype = torch.float32
+ mock_config.dtype = torch.float32
return mock_config
@@ -105,74 +103,77 @@ def cfg_head_hook():
def model():
return HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")
+
@pytest.fixture
def activation_store(cfg, model):
return ActivationsStore(cfg, model)
+
@pytest.fixture
def activation_store_head_hook(cfg_head_hook, model):
return ActivationsStore(cfg_head_hook, model)
+
def test_activations_store__init__(cfg, model):
-
store = ActivationsStore(cfg, model)
-
+
assert store.cfg == cfg
assert store.model == model
-
+
assert isinstance(store.dataset, IterableDataset)
assert isinstance(store.iterable_dataset, Iterable)
-
+
# I expect the dataloader to be initialised
assert hasattr(store, "dataloader")
-
+
# I expect the buffer to be initialised
assert hasattr(store, "storage_buffer")
-
+
# the rest is in the dataloader.
- expected_size = cfg.store_batch_size*cfg.context_size*cfg.n_batches_in_buffer //2
+ expected_size = (
+ cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer // 2
+ )
assert store.storage_buffer.shape == (expected_size, cfg.d_in)
-
-
+
+
def test_activations_store__get_batch_tokens(activation_store):
-
batch = activation_store.get_batch_tokens()
-
+
assert isinstance(batch, torch.Tensor)
- assert batch.shape == (activation_store.cfg.store_batch_size, activation_store.cfg.context_size)
+ assert batch.shape == (
+ activation_store.cfg.store_batch_size,
+ activation_store.cfg.context_size,
+ )
assert batch.device == activation_store.cfg.device
-
+
+
def test_activations_store__get_activations(activation_store):
-
batch = activation_store.get_batch_tokens()
activations = activation_store.get_activations(batch)
-
+
cfg = activation_store.cfg
assert isinstance(activations, torch.Tensor)
assert activations.shape == (cfg.store_batch_size, cfg.context_size, cfg.d_in)
assert activations.device == cfg.device
-
+
+
def test_activations_store__get_activations_head_hook(activation_store_head_hook):
-
batch = activation_store_head_hook.get_batch_tokens()
activations = activation_store_head_hook.get_activations(batch)
-
+
cfg = activation_store_head_hook.cfg
assert isinstance(activations, torch.Tensor)
assert activations.shape == (cfg.store_batch_size, cfg.context_size, cfg.d_in)
assert activations.device == cfg.device
-
+
+
def test_activations_store__get_buffer(activation_store):
-
-
n_batches_in_buffer = 3
buffer = activation_store.get_buffer(n_batches_in_buffer)
cfg = activation_store.cfg
assert isinstance(buffer, torch.Tensor)
- buffer_size_expected = cfg.store_batch_size*cfg.context_size* n_batches_in_buffer
-
+ buffer_size_expected = cfg.store_batch_size * cfg.context_size * n_batches_in_buffer
+
assert buffer.shape == (buffer_size_expected, cfg.d_in)
assert buffer.device == cfg.device
-
-
diff --git a/tests/unit/test_sparse_autoencoder.py b/tests/unit/test_sparse_autoencoder.py
index 7224893a..63422b09 100644
--- a/tests/unit/test_sparse_autoencoder.py
+++ b/tests/unit/test_sparse_autoencoder.py
@@ -12,6 +12,7 @@
TEST_MODEL = "tiny-stories-1M"
TEST_DATASET = "roneneldan/TinyStories"
+
@pytest.fixture
def cfg():
"""
@@ -48,11 +49,12 @@ def cfg():
mock_config.device = "cpu"
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
- # mock_config.dtype = torch.bfloat16
+ # mock_config.dtype = torch.bfloat16
mock_config.dtype = torch.float32
return mock_config
+
@pytest.fixture
def sparse_autoencoder(cfg):
"""
@@ -60,142 +62,161 @@ def sparse_autoencoder(cfg):
"""
return SparseAutoencoder(cfg)
+
@pytest.fixture
def model():
return HookedTransformer.from_pretrained(TEST_MODEL)
+
@pytest.fixture
def activation_store(cfg, model):
return ActivationsStore(cfg, model)
+
def test_sparse_autoencoder_init(cfg):
-
sparse_autoencoder = SparseAutoencoder(cfg)
-
+
assert isinstance(sparse_autoencoder, SparseAutoencoder)
-
- assert sparse_autoencoder.W_enc.shape == (cfg.d_in, cfg.d_sae)
+
+ assert sparse_autoencoder.W_enc.shape == (cfg.d_in, cfg.d_sae)
assert sparse_autoencoder.W_dec.shape == (cfg.d_sae, cfg.d_in)
assert sparse_autoencoder.b_enc.shape == (cfg.d_sae,)
assert sparse_autoencoder.b_dec.shape == (cfg.d_in,)
-
+
# assert decoder columns have unit norm
assert torch.allclose(
- torch.norm(sparse_autoencoder.W_dec, dim=1),
- torch.ones(cfg.d_sae)
+ torch.norm(sparse_autoencoder.W_dec, dim=1), torch.ones(cfg.d_sae)
)
+
def test_save_model(cfg):
-
with tempfile.TemporaryDirectory() as tmpdirname:
-
# assert file does not exist
assert os.path.exists(tmpdirname + "/test.pt") == False
-
+
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder.save_model(tmpdirname + "/test.pt")
-
+
assert os.path.exists(tmpdirname + "/test.pt")
-
+
state_dict_original = sparse_autoencoder.state_dict()
state_dict_loaded = torch.load(tmpdirname + "/test.pt")
-
+
# check for cfg and state_dict keys
assert "cfg" in state_dict_loaded
assert "state_dict" in state_dict_loaded
-
+
# check cfg matches the original
assert state_dict_loaded["cfg"] == cfg
-
+
# check state_dict matches the original
for key in sparse_autoencoder.state_dict().keys():
assert torch.allclose(
state_dict_original[key], # pylint: disable=unsubscriptable-object
- state_dict_loaded["state_dict"][key]
+ state_dict_loaded["state_dict"][key],
)
+
def test_load_from_pretrained_pt(cfg):
-
with tempfile.TemporaryDirectory() as tmpdirname:
-
# assert file does not exist
assert os.path.exists(tmpdirname + "/test.pt") == False
-
+
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
sparse_autoencoder.save_model(tmpdirname + "/test.pt")
-
+
assert os.path.exists(tmpdirname + "/test.pt")
-
- sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(tmpdirname + "/test.pt")
- sparse_autoencoder_loaded.cfg.device = "cpu" # might autoload onto mps
+
+ sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(
+ tmpdirname + "/test.pt"
+ )
+ sparse_autoencoder_loaded.cfg.device = "cpu" # might autoload onto mps
sparse_autoencoder_loaded = sparse_autoencoder_loaded.to("cpu")
sparse_autoencoder_loaded_state_dict = sparse_autoencoder_loaded.state_dict()
# check cfg matches the original
assert sparse_autoencoder_loaded.cfg == cfg
-
+
# check state_dict matches the original
for key in sparse_autoencoder.state_dict().keys():
assert torch.allclose(
- sparse_autoencoder_state_dict[key], # pylint: disable=unsubscriptable-object
- sparse_autoencoder_loaded_state_dict[key] # pylint: disable=unsubscriptable-object
+ sparse_autoencoder_state_dict[
+ key
+ ], # pylint: disable=unsubscriptable-object
+ sparse_autoencoder_loaded_state_dict[
+ key
+ ], # pylint: disable=unsubscriptable-object
)
-
+
+
def test_load_from_pretrained_pkl_gz(cfg):
-
with tempfile.TemporaryDirectory() as tmpdirname:
-
# assert file does not exist
assert os.path.exists(tmpdirname + "/test.pkl.gz") == False
-
+
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
sparse_autoencoder.save_model(tmpdirname + "/test.pkl.gz")
-
+
assert os.path.exists(tmpdirname + "/test.pkl.gz")
-
- sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(tmpdirname + "/test.pkl.gz")
- sparse_autoencoder_loaded.cfg.device = "cpu" # might autoload onto mps
+
+ sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(
+ tmpdirname + "/test.pkl.gz"
+ )
+ sparse_autoencoder_loaded.cfg.device = "cpu" # might autoload onto mps
sparse_autoencoder_loaded = sparse_autoencoder_loaded.to("cpu")
sparse_autoencoder_loaded_state_dict = sparse_autoencoder_loaded.state_dict()
# check cfg matches the original
assert sparse_autoencoder_loaded.cfg == cfg
-
+
# check state_dict matches the original
for key in sparse_autoencoder.state_dict().keys():
assert torch.allclose(
- sparse_autoencoder_state_dict[key], # pylint: disable=unsubscriptable-object
- sparse_autoencoder_loaded_state_dict[key] # pylint: disable=unsubscriptable-object
+ sparse_autoencoder_state_dict[
+ key
+ ], # pylint: disable=unsubscriptable-object
+ sparse_autoencoder_loaded_state_dict[
+ key
+ ], # pylint: disable=unsubscriptable-object
)
-
+
+
def test_sparse_autoencoder_forward(sparse_autoencoder):
-
batch_size = 32
- d_in =sparse_autoencoder.d_in
+ d_in = sparse_autoencoder.d_in
d_sae = sparse_autoencoder.d_sae
-
+
x = torch.randn(batch_size, d_in)
- sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss = sparse_autoencoder.forward(
+ (
+ sae_out,
+ feature_acts,
+ loss,
+ mse_loss,
+ l1_loss,
+ ghost_grad_loss,
+ ) = sparse_autoencoder.forward(
x,
)
-
+
assert sae_out.shape == (batch_size, d_in)
assert feature_acts.shape == (batch_size, d_sae)
assert loss.shape == ()
assert mse_loss.shape == ()
assert l1_loss.shape == ()
assert torch.allclose(loss, mse_loss + l1_loss)
-
+
x_centred = x - x.mean(dim=0, keepdim=True)
- expected_mse_loss = (torch.pow((sae_out-x.float()), 2) / (x_centred**2).sum(dim=-1, keepdim=True).sqrt()).mean()
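+    # Reference loss: per-example squared error normalised by the L2 norm of
+    # the centred input, then averaged.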
+ expected_mse_loss = (
+ torch.pow((sae_out - x.float()), 2)
+ / (x_centred**2).sum(dim=-1, keepdim=True).sqrt()
+ ).mean()
assert torch.allclose(mse_loss, expected_mse_loss)
- expected_l1_loss = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
+ expected_l1_loss = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
assert torch.allclose(l1_loss, sparse_autoencoder.l1_coefficient * expected_l1_loss)
-
+
# check everything has the right dtype
assert sae_out.dtype == sparse_autoencoder.dtype
assert feature_acts.dtype == sparse_autoencoder.dtype
assert loss.dtype == sparse_autoencoder.dtype
assert mse_loss.dtype == sparse_autoencoder.dtype
assert l1_loss.dtype == sparse_autoencoder.dtype
-
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 0ee2a68a..787242de 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -12,6 +12,7 @@
TEST_MODEL = "tiny-stories-1M"
TEST_DATASET = "roneneldan/TinyStories"
+
@pytest.fixture
def cfg():
"""
@@ -45,7 +46,7 @@ def cfg():
mock_config.device = "cpu"
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
- mock_config.dtype = torch.float32
+ mock_config.dtype = torch.float32
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = None
@@ -53,36 +54,36 @@ def cfg():
def test_LMSparseAutoencoderSessionloader_init(cfg):
-
loader = LMSparseAutoencoderSessionloader(cfg)
assert loader.cfg == cfg
-
+
+
def test_LMSparseAutoencoderSessionloader_load_session(cfg):
-
loader = LMSparseAutoencoderSessionloader(cfg)
model, sparse_autoencoder, activations_loader = loader.load_session()
-
+
assert isinstance(model, HookedTransformer)
assert isinstance(sparse_autoencoder, SparseAutoencoder)
assert isinstance(activations_loader, ActivationsStore)
def test_LMSparseAutoencoderSessionloader_load_session_from_trained(cfg):
-
loader = LMSparseAutoencoderSessionloader(cfg)
_, sparse_autoencoder, _ = loader.load_session()
-
+
with tempfile.TemporaryDirectory() as tmpdirname:
tempfile_path = f"{tmpdirname}/test.pt"
sparse_autoencoder.save_model(tempfile_path)
-
- _, new_sparse_autoencoder, _ = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
- tempfile_path
- )
+
+ (
+ _,
+ new_sparse_autoencoder,
+ _,
+ ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(tempfile_path)
new_sparse_autoencoder.cfg.device = "cpu"
new_sparse_autoencoder.to("cpu")
assert new_sparse_autoencoder.cfg == sparse_autoencoder.cfg
# assert weights are the same
new_parameters = dict(new_sparse_autoencoder.named_parameters())
for name, param in sparse_autoencoder.named_parameters():
- assert torch.allclose(param, new_parameters[name])
\ No newline at end of file
+ assert torch.allclose(param, new_parameters[name])