flake8 linter changes
Benw8888 committed Feb 28, 2024
1 parent 082c813 commit 8e41e59
Showing 6 changed files with 116 additions and 104 deletions.
49 changes: 32 additions & 17 deletions sae_training/activations_store.py
@@ -26,10 +26,12 @@ def __init__(
         # Check if dataset is tokenized
         dataset_sample = next(self.iterable_dataset)
         self.cfg.is_dataset_tokenized = "tokens" in dataset_sample.keys()
-        print(f"Dataset is {'tokenized' if self.cfg.is_dataset_tokenized else 'not tokenized'}! Updating config.")
+        print(
+            f"Dataset is {'tokenized' if self.cfg.is_dataset_tokenized else 'not tokenized'}! Updating config."
+        )
         self.iterable_dataset = iter(self.dataset)  # Reset iterator after checking

-        if self.cfg.use_cached_activations: #EDIT: load from multi-layer acts
+        if self.cfg.use_cached_activations:  # EDIT: load from multi-layer acts
             # Sanity check: does the cache directory exist?
             assert os.path.exists(
                 self.cfg.cached_activations_path
@@ -145,45 +147,57 @@ def get_activations(self, batch_tokens, get_loss=False):
         """
         Returns activations of shape (batches, context, num_layers, d_in)
         """
-        layers = self.cfg.hook_point_layer if isinstance(self.cfg.hook_point_layer, list) else [self.cfg.hook_point_layer]
-        act_names = [self.cfg.hook_point.format(layer = layer) for layer in layers]
+        layers = (
+            self.cfg.hook_point_layer
+            if isinstance(self.cfg.hook_point_layer, list)
+            else [self.cfg.hook_point_layer]
+        )
+        act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
         hook_point_max_layer = max(layers)
         if self.cfg.hook_point_head_index is not None:
             layerwise_activations = self.model.run_with_cache(
-                batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1
+                batch_tokens,
+                names_filter=act_names,
+                stop_at_layer=hook_point_max_layer + 1,
             )[1]
             activations_list = [
-                layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
+                layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
                 for act_name in act_names
             ]
         else:
             layerwise_activations = self.model.run_with_cache(
-                batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1
+                batch_tokens,
+                names_filter=act_names,
+                stop_at_layer=hook_point_max_layer + 1,
             )[1]
             activations_list = [
-                layerwise_activations[act_name]
-                for act_name in act_names
+                layerwise_activations[act_name] for act_name in act_names
             ]

         # Stack along a new dimension to keep separate layers distinct
         stacked_activations = torch.stack(activations_list, dim=2)

         return stacked_activations

     def get_buffer(self, n_batches_in_buffer):
         context_size = self.cfg.context_size
         batch_size = self.cfg.store_batch_size
         d_in = self.cfg.d_in
         total_size = batch_size * n_batches_in_buffer
-        num_layers = len(self.cfg.hook_point_layer) if isinstance(self.cfg.hook_point_layer, list) \
-            else 1  # Number of hook points or layers
+        num_layers = (
+            len(self.cfg.hook_point_layer)
+            if isinstance(self.cfg.hook_point_layer, list)
+            else 1
+        )  # Number of hook points or layers

         if self.cfg.use_cached_activations:
             # Load the activations from disk
             buffer_size = total_size * context_size
             # Initialize an empty tensor with an additional dimension for layers
             new_buffer = torch.zeros(
-                (buffer_size, num_layers, d_in), dtype=self.cfg.dtype, device=self.cfg.device
+                (buffer_size, num_layers, d_in),
+                dtype=self.cfg.dtype,
+                device=self.cfg.device,
             )
             n_tokens_filled = 0
Expand Down Expand Up @@ -215,8 +229,9 @@ def get_buffer(self, n_batches_in_buffer):
activations = activations[: buffer_size - n_tokens_filled, ...]
taking_subset_of_file = True

new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0], ...] = activations

new_buffer[
n_tokens_filled : n_tokens_filled + activations.shape[0], ...
] = activations

if taking_subset_of_file:
self.next_idx_within_buffer = activations.shape[0]
@@ -235,7 +250,7 @@ def get_buffer(self, n_batches_in_buffer):
             dtype=self.cfg.dtype,
             device=self.cfg.device,
         )
-
+
         for refill_batch_idx_start in refill_iterator:
             refill_batch_tokens = self.get_batch_tokens()
             refill_activations = self.get_activations(refill_batch_tokens)
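For readers skimming the reformatted get_activations and get_buffer code above, the sketch below restates the core idea in isolation: hook_point_layer may be an int or a list of ints, each layer is substituted into the hook_point template, and the per-layer activations are stacked along a new dim=2 to give a (batch, context, num_layers, d_in) tensor. The toy sizes and the dict standing in for the run_with_cache activation cache are invented for illustration and are not part of this commit.

# Minimal, self-contained sketch of the multi-layer activation stacking pattern.
import torch

hook_point = "blocks.{layer}.hook_resid_pre"  # template, as in the config
hook_point_layer = [3, 5]                     # could also be a single int, e.g. 3
batch, context, d_in = 2, 4, 8

layers = (
    hook_point_layer
    if isinstance(hook_point_layer, list)
    else [hook_point_layer]
)
act_names = [hook_point.format(layer=layer) for layer in layers]

# Stand-in for the cache returned by model.run_with_cache(...)[1]
layerwise_activations = {name: torch.randn(batch, context, d_in) for name in act_names}

activations_list = [layerwise_activations[name] for name in act_names]
stacked_activations = torch.stack(activations_list, dim=2)
print(stacked_activations.shape)  # torch.Size([2, 4, 2, 8])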
10 changes: 4 additions & 6 deletions sae_training/config.py
@@ -21,9 +21,9 @@ class RunnerConfig(ABC):
     is_dataset_tokenized: bool = True
     context_size: int = 128
     use_cached_activations: bool = False
-    cached_activations_path: Optional[str] = (
-        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
-    )
+    cached_activations_path: Optional[
+        str
+    ] = None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"

     # SAE Parameters
     d_in: int = 512
@@ -61,9 +61,7 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
     l1_coefficient: float = 1e-3
     lp_norm: float = 1
     lr: float = 3e-4
-    lr_scheduler_name: str = (
-        "constantwithwarmup"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
-    )
+    lr_scheduler_name: str = "constantwithwarmup"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
     lr_warm_up_steps: int = 500
     train_batch_size: int = 4096

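The comment on cached_activations_path documents a default of the form "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}". As a hedged illustration only, the snippet below shows one way such a default could be derived in a dataclass __post_init__; the field names dataset_path and model_name and the path-joining details are assumptions, not taken from this commit.

# Hypothetical sketch: assembling the documented default for cached_activations_path.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniRunnerConfig:
    dataset_path: str = "roneneldan/TinyStories"  # assumed field name and value
    model_name: str = "tiny-stories-1M"           # assumed field name and value
    hook_point: str = "blocks.{layer}.hook_resid_pre"
    hook_point_layer: int = 3
    hook_point_head_index: Optional[int] = None
    cached_activations_path: Optional[str] = None

    def __post_init__(self):
        if self.cached_activations_path is None:
            full_hook_name = self.hook_point.format(layer=self.hook_point_layer)
            self.cached_activations_path = (
                f"activations/{self.dataset_path.replace('/', '_')}/"
                f"{self.model_name.replace('/', '_')}/{full_hook_name}"
            )
            if self.hook_point_head_index is not None:
                self.cached_activations_path += f"_{self.hook_point_head_index}"


print(MiniRunnerConfig().cached_activations_path)
# activations/roneneldan_TinyStories/tiny-stories-1M/blocks.3.hook_resid_pre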
13 changes: 7 additions & 6 deletions sae_training/sae_group.py
@@ -1,4 +1,3 @@
-from functools import partial
 import gzip
 import os
 import pickle
@@ -29,7 +28,9 @@ def _init_autoencoders(self, cfg):
             params = dict(zip(keys, combination))
             cfg_copy = dataclasses.replace(cfg, **params)
             # Insert the layer into the hookpoint
-            cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg_copy.hook_point_layer)
+            cfg_copy.hook_point = cfg_copy.hook_point.format(
+                layer=cfg_copy.hook_point_layer
+            )
             # Create and store both the SparseAutoencoder instance and its parameters
             self.autoencoders.append(SparseAutoencoder(cfg_copy))

@@ -41,7 +42,7 @@ def __iter__(self):
     def __len__(self):
         # Return the number of SparseAutoencoder instances
         return len(self.autoencoders)
-
+
     @classmethod
     def load_from_pretrained(cls, path: str):
         """
@@ -95,7 +96,7 @@ def load_from_pretrained(cls, path: str):
         # instance.load_state_dict(state_dict["state_dict"])

         # return instance
-
+
     def save_model(self, path: str):
         """
         Basic save function for the model. Saves the model's state_dict and the config used to train it.
@@ -116,7 +117,7 @@ def save_model(self, path: str):
         )

         print(f"Saved model to {path}")
-
+
     def get_name(self):
         layers = self.cfg.hook_point_layer
         if not isinstance(layers, list):
@@ -126,4 +127,4 @@ def get_name(self):
         else:
             layer_string = f"{layers[0]}"
         sae_name = f"sae_group_{self.cfg.model_name}_{self.cfg.hook_point.format(layer=layer_string)}_{self.cfg.d_sae}"
-        return sae_name
\ No newline at end of file
+        return sae_name
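To make the reformatted _init_autoencoders hunk easier to follow, here is a self-contained sketch of the same pattern: list-valued config fields are expanded into one config per combination with dataclasses.replace, and the layer is substituted into the hook_point template. The itertools.product sweep, the MiniConfig fields, and the printed values are illustrative assumptions, not code from this repository.

# Minimal sketch of expanding a config sweep into per-SAE configs.
import dataclasses
import itertools


@dataclasses.dataclass
class MiniConfig:
    hook_point: str = "blocks.{layer}.hook_resid_pre"
    hook_point_layer: int = 0
    l1_coefficient: float = 1e-3


cfg = MiniConfig()
sweep = {"hook_point_layer": [3, 5], "l1_coefficient": [1e-3, 4e-3]}

keys = list(sweep.keys())
configs = []
for combination in itertools.product(*sweep.values()):
    params = dict(zip(keys, combination))
    cfg_copy = dataclasses.replace(cfg, **params)
    # Insert the layer into the hookpoint, as in the hunk above
    cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg_copy.hook_point_layer)
    configs.append(cfg_copy)

for c in configs:
    print(c.hook_point, c.l1_coefficient)
# Prints four lines: blocks.3 / blocks.5 hook_resid_pre, each with l1 0.001 and 0.004.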
2 changes: 1 addition & 1 deletion sae_training/sparse_autoencoder.py
@@ -133,7 +133,7 @@ def forward(self, x, dead_neuron_mask=None):
         loss = mse_loss + l1_loss + mse_loss_ghost_resid

         return sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid
-
+
     @torch.no_grad()
     def initialize_b_dec_with_precalculated(self, origin):
         out = torch.tensor(origin, dtype=self.dtype, device=self.device)