Commit b95628b

Lots of changes and cleanup, still many TODOs

1 parent aa667b1
File tree: 13 files changed (+435, -347 lines)

autoencoder/autoencoder.py (26 additions, 17 deletions)

(Paired -/+ lines below with identical visible text are presumably whitespace-only changes.)
@@ -1,10 +1,11 @@
"""
-This file defines an AutoEncoder class, which also contains an implementation of neuron resampling.
+This file defines an AutoEncoder class, which also contains an implementation of neuron resampling.
"""

-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+

class AutoEncoder(nn.Module):
    def __init__(self, n_inputs: int, n_latents: int, lam: float = 0.003, resampling_interval: int = 25000):
@@ -15,7 +16,8 @@ def __init__(self, n_inputs: int, n_latents: int, lam: float = 0.003, resampling
            resampling_interval: Number of training steps after which dead neurons will be resampled
        """
        super().__init__()
-        self.n_inputs, self.n_latents = n_inputs, n_latents
+        self.n_inputs = n_inputs
+        self.n_latents = n_latents
        self.encoder = nn.Linear(n_inputs, n_latents)
        self.relu = nn.ReLU()
        self.decoder = nn.Linear(n_latents, n_inputs)
@@ -37,7 +39,7 @@ def forward(self, x):
            'latents': latents,
            'reconst_acts': reconstructed,
            'mse_loss': self.mse_loss(reconstructed, x),
-            'l1_loss': self.l1_loss(latents)
+            'l1_loss': self.l1_loss(latents),
        }

    def encode(self, x):
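Since forward() returns the two loss terms separately, combining them is left to the training loop. A minimal sketch of the usual sparse-autoencoder objective follows; it is an assumed usage, not part of this commit, and it presumes __init__ stores lam as self.lam and that model and optimizer exist:

out = model(x)  # x: a batch of MLP activations, model: AutoEncoder
loss = out['mse_loss'] + model.lam * out['l1_loss']  # L = MSE + lam * L1
optimizer.zero_grad(set_to_none=True)
loss.backward()
model.remove_parallel_component_of_decoder_grad()  # defined further below
optimizer.step()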
@@ -66,16 +68,16 @@ def get_feature_activations(self, inputs, start_idx, end_idx):
        :param inputs: Input tensor of shape (..., n) where n = d_MLP. It includes batch dimensions.
        :param start_idx: Starting index (inclusive) of the feature subset.
        :param end_idx: Ending index (exclusive) of the feature subset.
-
-        Returns the activations for the specified feature range, reducing computation by
+
+        Returns the activations for the specified feature range, reducing computation by
        only processing the necessary part of the network's weights and biases.
        """
        adjusted_inputs = inputs - self.decoder.bias  # Adjust input to account for decoder bias
        weight_subset = self.encoder.weight[start_idx:end_idx, :].t()  # Transpose the subset of weights
        bias_subset = self.encoder.bias[start_idx:end_idx]
-
+
        activations = self.relu(adjusted_inputs @ weight_subset + bias_subset)
-
+
        return activations

    @torch.no_grad()
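A quick way to sanity-check get_feature_activations is against a full encode followed by slicing. A hypothetical sketch, assuming encode() applies the same decoder-bias adjustment before the encoder (this diff does not show encode's body):

import torch

model = AutoEncoder(n_inputs=512, n_latents=2048)
x = torch.randn(32, 512)
subset = model.get_feature_activations(x, start_idx=100, end_idx=200)
full = model.encode(x)  # assumed: relu(encoder(x - decoder.bias))
assert torch.allclose(subset, full[..., 100:200], atol=1e-6)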
@@ -90,26 +92,33 @@ def remove_parallel_component_of_decoder_grad(self):
        """
        Remove the component of the gradient parallel to the decoder's weight vectors.
        """
-        unit_weights = F.normalize(self.decoder.weight, dim=0) # \hat{b}
-        proj = (self.decoder.weight.grad * unit_weights).sum(dim=0) * unit_weights
+        unit_weights = F.normalize(self.decoder.weight, dim=0) # \hat{b}
+        proj = (self.decoder.weight.grad * unit_weights).sum(dim=0) * unit_weights
        self.decoder.weight.grad = self.decoder.weight.grad - proj

-    @staticmethod
+    @staticmethod
    def is_dead_neuron_investigation_step(step, resampling_interval, num_resamples):
        """
        Determine if the current step is the start of a phase for investigating dead neurons.
        According to Anthropic's specified policy, it occurs at odd multiples of half the resampling interval.
        """
-        return (step > 0) and step % (resampling_interval // 2) == 0 and (step // (resampling_interval // 2)) % 2 != 0 and step < resampling_interval * num_resamples
+        return (
+            (step > 0)
+            and step % (resampling_interval // 2) == 0
+            and (step // (resampling_interval // 2)) % 2 != 0
+            and step < resampling_interval * num_resamples
+        )

    @staticmethod
    def is_within_neuron_investigation_phase(step, resampling_interval, num_resamples):
        """
        Check if the current step is within a phase where active neurons are investigated.
        This phase occurs in intervals defined in the specified range, starting at odd multiples of half the resampling interval.
        """
-        return any(milestone - resampling_interval // 2 <= step < milestone
-                   for milestone in range(resampling_interval, resampling_interval * (num_resamples + 1), resampling_interval))
+        return any(
+            milestone - resampling_interval // 2 <= step < milestone
+            for milestone in range(resampling_interval, resampling_interval * (num_resamples + 1), resampling_interval)
+        )

    @torch.no_grad()
    def initiate_dead_neurons(self):
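The reformatted predicate above can be checked concretely. A minimal sketch, with the default resampling_interval of 25000 and an illustrative num_resamples of 4 (both values chosen for illustration only):

# Investigation starts at odd multiples of half the interval (12500) that
# fall before the final resampling milestone at 25000 * 4 = 100000.
for step in [0, 12500, 25000, 37500, 62500, 87500, 100000]:
    print(step, AutoEncoder.is_dead_neuron_investigation_step(step, 25000, 4))
# Expected: True only at 12500, 37500, 62500, 87500. The companion
# is_within_neuron_investigation_phase then holds over [12500, 25000),
# [37500, 50000), and so on, up to each resampling milestone.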
@@ -138,7 +147,7 @@ def resample_dead_neurons(self, data, optimizer, batch_size=8192):
        average_enc_norm = self._compute_average_norm_of_alive_neurons(alive_neurons)
        probs = self._compute_loss_probabilities(data, batch_size, device)
        selected_examples = self._select_examples_based_on_probabilities(data, probs)
-
+
        self._resample_neurons(selected_examples, dead_neurons_t, average_enc_norm, device)
        self._update_optimizer_parameters(optimizer, dead_neurons_t)

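As an aside on remove_parallel_component_of_decoder_grad above: after the projection, the remaining gradient should be orthogonal to every decoder column, which keeps unit-norm decoder directions unit-norm to first order. A standalone numeric sketch (tensor names and shapes are illustrative assumptions, not from this commit):

import torch
import torch.nn.functional as F

# decoder.weight has shape (n_inputs, n_latents); each column is one
# feature direction. Remove the gradient component parallel to it.
W = torch.randn(512, 2048)
G = torch.randn(512, 2048)           # stand-in for decoder.weight.grad
u = F.normalize(W, dim=0)            # \hat{b}: unit vector per column
proj = (G * u).sum(dim=0) * u        # parallel component, column by column
G_perp = G - proj
print((G_perp * u).sum(dim=0).abs().max())  # ~0: orthogonal to every column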
autoencoder/configurator.py (3 additions, 3 deletions)
@@ -20,15 +20,15 @@
for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
-        assert not arg.startswith('--')
+        assert not arg.startswith('--'), f'{arg} is not a valid config file'
        config_file = arg
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
-        assert arg.startswith('--')
+        assert arg.startswith('--'), f'{arg} is not a valid --key=value argument'
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
@@ -39,7 +39,7 @@
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
-            assert type(attempt) == type(globals()[key])
+            assert type(attempt) == type(globals()[key]), f'{key} is {type(globals()[key])} and {attempt} is {type(attempt)}'
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
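For context on how this configurator is used (a nanoGPT-style pattern; the script name and flag values below are hypothetical, not from this commit): a training script defines its defaults as module-level globals and then exec()s this file, which overrides them from the command line.

# In a hypothetical train.py:
lam = 0.003                  # default, may be overridden below
resampling_interval = 25000
exec(open('autoencoder/configurator.py').read())

# Shell usage: a bare argument is read as a config file; --key=value
# arguments override individual globals (types must match):
#   python train.py sweep.py --lam=0.001 --resampling_interval=10000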
