implemented sweeping via config list
Benw8888 committed Feb 27, 2024
1 parent 2ba2131 commit 80f61fa
Showing 7 changed files with 143 additions and 76 deletions.
2 changes: 1 addition & 1 deletion activation_storing.py
@@ -26,7 +26,7 @@

# Activation Store Parameters
n_batches_in_buffer = 16,
- total_training_tokens = 300_000_000,
+ total_training_tokens = 10_000_000,
store_batch_size = 64,

# Activation caching shuffle parameters
32 changes: 26 additions & 6 deletions lp_sae_training.py
@@ -12,8 +12,8 @@

# Data Generating Function (Model + Training Distibuion)
model_name = "gpt2-small",
hook_point = "blocks.2.hook_resid_pre",
hook_point_layer = 2,
hook_point = "blocks.{layer}.hook_resid_pre",
hook_point_layer = 6,
d_in = 768,
dataset_path = "Skylion007/openwebtext",
is_dataset_tokenized=False,
@@ -24,15 +24,35 @@

# Training Parameters
lr = 4e-4,
- l1_coefficient = 8e-5,
+ l1_coefficient = [ #8e-5
+ # 1e-9,
+ # 1e-8,
+ 1e-7,
+ 1e-6,
+ 1e-5,
+ 1e-4,
+ ],
+ lp_norm = [
+ # 0.1,
+ # 0.2,
+ # 0.3,
+ # 0.4,
+ 0.5,
+ 0.6,
+ 0.7,
+ 0.8,
+ 0.9,
+ 1,
+ # 1.1,
+ ],
lr_scheduler_name="constantwithwarmup",
train_batch_size = 4096,
context_size = 128,
lr_warm_up_steps=5000,

# Activation Store Parameters
n_batches_in_buffer = 128,
- total_training_tokens = 1_000_000 * 300,
+ total_training_tokens = 300_000_000,
store_batch_size = 32,

# Dead Neurons and Sparsity
@@ -50,10 +70,10 @@
# Misc
device = "cuda",
seed = 42,
- n_checkpoints = 10,
+ n_checkpoints = 2,
checkpoint_path = "checkpoints",
dtype = torch.float32,
- use_cached_activations = True,
+ use_cached_activations = False,
)

sparse_autoencoder = language_model_sae_runner(cfg)
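For context on the sweep mechanism this config feeds (a minimal sketch, not the repository's actual runner): list-valued fields such as `l1_coefficient` and `lp_norm` above act as sweep axes, and one concrete config is produced per element of their cartesian product. The `ExampleConfig` dataclass and `expand_sweep` helper below are illustrative stand-ins.

```python
# Minimal sketch (assumed names): expand list-valued config fields into a sweep.
import dataclasses
from itertools import product


@dataclasses.dataclass
class ExampleConfig:
    lr: float = 4e-4
    l1_coefficient: object = 8e-5  # a single value or a list of values to sweep
    lp_norm: object = 1.0          # a single value or a list of values to sweep


def expand_sweep(cfg):
    # Every field whose value is a list defines one sweep axis.
    sweep = {k: v for k, v in vars(cfg).items() if isinstance(v, list)}
    if not sweep:
        return [cfg]
    keys, values = zip(*sweep.items())
    # One concrete config per combination in the cartesian product.
    return [dataclasses.replace(cfg, **dict(zip(keys, combo)))
            for combo in product(*values)]


cfgs = expand_sweep(ExampleConfig(l1_coefficient=[1e-7, 1e-6], lp_norm=[0.5, 1.0]))
print(len(cfgs))  # 4 configs: 2 l1 values x 2 lp_norm values
```

With the lists above (4 active `l1_coefficient` values and 6 active `lp_norm` values), the sweep would train 24 SAEs in a single run.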
7 changes: 4 additions & 3 deletions sae_training/activations_store.py
@@ -148,7 +148,7 @@ def get_activations(self, batch_tokens, get_loss=False):
"""
layers = self.cfg.hook_point_layer if isinstance(self.cfg.hook_point_layer, list) else [self.cfg.hook_point_layer]
act_names = [self.cfg.hook_point.format(layer = layer) for layer in layers]
- hook_point_max_layer = self.cfg.hook_point_layer
+ hook_point_max_layer = max(layers)
if self.cfg.hook_point_head_index is not None:
layerwise_activations = self.model.run_with_cache(
batch_tokens, names_filter=act_names, stop_at_layer=hook_point_max_layer + 1
@@ -176,7 +176,8 @@ def get_buffer(self, n_batches_in_buffer):
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
total_size = batch_size * n_batches_in_buffer
- num_layers = len(self.cfg.hook_points) # Number of hook points or layers
+ num_layers = len(self.cfg.hook_point_layer) if isinstance(self.cfg.hook_point_layer, list) \
+ else 1 # Number of hook points or layers

if self.cfg.use_cached_activations:
# Load the activations from disk
@@ -268,7 +269,7 @@ def get_data_loader(
dim=0,
)

- mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[1])]
+ mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

# 2. put 50 % in storage
self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
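A small self-contained check of the shuffle fix above (toy shapes, not the real buffer): the permutation must be taken over the first dimension of `mixing_buffer`, since that is the sample axis being shuffled; a permutation of `shape[1]` has the wrong length when used as a row index.

```python
import torch

# Toy mixing buffer: 8 samples (rows), 4 features (columns).
mixing_buffer = torch.arange(32, dtype=torch.float32).reshape(8, 4)

# Fixed behaviour: permute the sample dimension, reordering whole rows.
shuffled = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
assert shuffled.shape == mixing_buffer.shape
assert torch.equal(shuffled.sum(dim=0), mixing_buffer.sum(dim=0))  # same rows, new order

# Old behaviour: randperm(shape[1]) has length 4, so indexing rows with it
# silently keeps only the first 4 samples (in a random order).
truncated = mixing_buffer[torch.randperm(mixing_buffer.shape[1])]
assert truncated.shape == (4, 4)
```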
16 changes: 9 additions & 7 deletions sae_training/sae_group.py
@@ -16,22 +16,24 @@ def _init_autoencoders(self, cfg):

# Extract all hyperparameter lists from cfg
hyperparameters = {k: v for k, v in vars(cfg).items() if isinstance(v, list)}
- keys, values = zip(*hyperparameters.items())
+ if len(hyperparameters) > 0:
+ keys, values = zip(*hyperparameters.items())
+ else:
+ keys, values = (), ([()],) # Ensure product(*values) yields one combination

# Create all combinations of hyperparameters
for combination in product(*values):
- cfg_copy = dataclasses.replace(cfg)
params = dict(zip(keys, combination))
- cfg_copy.update(params)
+ cfg_copy = dataclasses.replace(cfg, **params)
# Insert the layer into the hookpoint
- cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg.copy.hook_point_layer)
+ cfg_copy.hook_point = cfg_copy.hook_point.format(layer=cfg_copy.hook_point_layer)
# Create and store both the SparseAutoencoder instance and its parameters
- self.autoencoders.append((SparseAutoencoder(cfg_copy), cfg_copy))
+ self.autoencoders.append(SparseAutoencoder(cfg_copy))

def __iter__(self):
# Make SAEGroup iterable over its SparseAutoencoder instances and their parameters
- for ae, params in self.autoencoders:
- yield ae, params # Yielding as a tuple
+ for ae in self.autoencoders:
+ yield ae # Yielding as a tuple

def __len__(self):
# Return the number of SparseAutoencoder instances
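A quick standalone check of the fallback introduced above (standard library only): when no config field is a list, `keys, values = (), ([()],)` makes `product(*values)` yield exactly one combination with no hyperparameter overrides, so a plain non-sweep config still builds a single autoencoder.

```python
from itertools import product

# Sweep case: two list-valued fields -> cartesian product of their values.
keys, values = ("l1_coefficient", "lp_norm"), ([1e-6, 1e-5], [0.5, 1.0])
assert len(list(product(*values))) == 4  # one SAE per combination

# No-sweep fallback: product(*values) still yields one (empty) combination,
# so the construction loop runs exactly once with no overrides.
keys, values = (), ([()],)
combos = list(product(*values))
assert len(combos) == 1
assert dict(zip(keys, combos[0])) == {}  # params passed to dataclasses.replace
```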
17 changes: 10 additions & 7 deletions sae_training/sparse_autoencoder.py
@@ -138,13 +138,18 @@ def forward(self, x, dead_neuron_mask=None):
loss = mse_loss + l1_loss + mse_loss_ghost_resid

return sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid

+ @torch.no_grad()
+ def initialize_b_dec_with_precalculated(self, origin):
+ out = torch.tensor(origin, dtype=self.dtype, device=self.device)
+ self.b_dec.data = out

@torch.no_grad()
- def initialize_b_dec(self, activation_store):
+ def initialize_b_dec(self, all_activations):
if self.cfg.b_dec_init_method == "geometric_median":
- self.initialize_b_dec_with_geometric_median(activation_store)
+ self.initialize_b_dec_with_geometric_median(all_activations)
elif self.cfg.b_dec_init_method == "mean":
- self.initialize_b_dec_with_mean(activation_store)
+ self.initialize_b_dec_with_mean(all_activations)
elif self.cfg.b_dec_init_method == "zeros":
pass
else:
@@ -153,9 +158,8 @@ def initialize_b_dec(self, activation_store):
)

@torch.no_grad()
- def initialize_b_dec_with_geometric_median(self, activation_store):
+ def initialize_b_dec_with_geometric_median(self, all_activations):
previous_b_dec = self.b_dec.clone().cpu()
- all_activations = activation_store.storage_buffer.detach().cpu()
out = compute_geometric_median(
all_activations, skip_typechecks=True, maxiter=100, per_component=False
).median
@@ -173,9 +177,8 @@ def initialize_b_dec_with_geometric_median(self, activation_store):
self.b_dec.data = out

@torch.no_grad()
- def initialize_b_dec_with_mean(self, activation_store):
+ def initialize_b_dec_with_mean(self, all_activations):
previous_b_dec = self.b_dec.clone().cpu()
- all_activations = activation_store.storage_buffer.detach().cpu()
out = all_activations.mean(dim=0)

previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
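To illustrate the refactor above: the `initialize_b_dec_*` methods now take a tensor of activations directly (the caller extracts it from the activation store, e.g. `activation_store.storage_buffer.detach().cpu()`), and the new `initialize_b_dec_with_precalculated` loads an already computed origin. The `ToySAE` class below is a stand-in written only to show the call pattern, not the repository's `SparseAutoencoder`.

```python
import torch

class ToySAE:
    """Toy stand-in for SparseAutoencoder, only to demonstrate the new call pattern."""

    def __init__(self, d_in, dtype=torch.float32, device="cpu"):
        self.dtype, self.device = dtype, device
        self.b_dec = torch.nn.Parameter(torch.zeros(d_in, dtype=dtype, device=device))

    @torch.no_grad()
    def initialize_b_dec_with_precalculated(self, origin):
        # Copy a precomputed origin (e.g. a geometric median found offline) into b_dec.
        self.b_dec.data = torch.tensor(origin, dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, all_activations):
        # Caller passes activations directly instead of handing over the whole store.
        self.b_dec.data = all_activations.mean(dim=0).to(self.dtype).to(self.device)

d_in = 8
all_activations = torch.randn(1024, d_in)   # stand-in for the store's storage buffer
sae = ToySAE(d_in)
sae.initialize_b_dec_with_mean(all_activations)
sae.initialize_b_dec_with_precalculated(all_activations.mean(dim=0).tolist())
```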