11"""
2- This file defines an AutoEncoder class, which also contains an implementation of neuron resampling.
2+ This file defines an AutoEncoder class, which also contains an implementation of neuron resampling.
33"""
44
5- import torch
6- import torch .nn as nn
7- import torch .nn .functional as F
5+ import torch
6+ import torch .nn as nn
7+ import torch .nn .functional as F
8+
89
910class AutoEncoder (nn .Module ):
1011 def __init__ (self , n_inputs : int , n_latents : int , lam : float = 0.003 , resampling_interval : int = 25000 ):
@@ -15,7 +16,8 @@ def __init__(self, n_inputs: int, n_latents: int, lam: float = 0.003, resampling
         resampling_interval: Number of training steps after which dead neurons will be resampled
         """
         super().__init__()
-        self.n_inputs, self.n_latents = n_inputs, n_latents
+        self.n_inputs = n_inputs
+        self.n_latents = n_latents
         self.encoder = nn.Linear(n_inputs, n_latents)
         self.relu = nn.ReLU()
         self.decoder = nn.Linear(n_latents, n_inputs)
@@ -37,7 +39,7 @@ def forward(self, x):
             'latents': latents,
             'reconst_acts': reconstructed,
             'mse_loss': self.mse_loss(reconstructed, x),
-            'l1_loss': self.l1_loss(latents)
+            'l1_loss': self.l1_loss(latents),
         }

     def encode(self, x):
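
Since `forward` returns both loss terms in its output dict, a training step can simply sum them before backprop and apply the gradient projection defined further down in this diff. A minimal sketch, assuming `l1_loss` already applies the `lam` coefficient and using illustrative sizes:

```python
import torch

sae = AutoEncoder(n_inputs=512, n_latents=2048)        # illustrative sizes
optimizer = torch.optim.Adam(sae.parameters(), lr=1e-4)

acts = torch.randn(4096, 512)                          # stand-in for MLP activations
out = sae(acts)
loss = out['mse_loss'] + out['l1_loss']                # assumes l1_loss already scales by lam

optimizer.zero_grad()
loss.backward()
sae.remove_parallel_component_of_decoder_grad()        # defined later in this diff
optimizer.step()
```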
@@ -66,16 +68,16 @@ def get_feature_activations(self, inputs, start_idx, end_idx):
         :param inputs: Input tensor of shape (..., n) where n = d_MLP. It includes batch dimensions.
         :param start_idx: Starting index (inclusive) of the feature subset.
         :param end_idx: Ending index (exclusive) of the feature subset.
-
-        Returns the activations for the specified feature range, reducing computation by
+
+        Returns the activations for the specified feature range, reducing computation by
         only processing the necessary part of the network's weights and biases.
         """
         adjusted_inputs = inputs - self.decoder.bias  # Adjust input to account for decoder bias
         weight_subset = self.encoder.weight[start_idx:end_idx, :].t()  # Transpose the subset of weights
         bias_subset = self.encoder.bias[start_idx:end_idx]
-
+
         activations = self.relu(adjusted_inputs @ weight_subset + bias_subset)
-
+
         return activations

     @torch.no_grad()
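
Because ReLU acts elementwise, the sliced pre-activation computed by `get_feature_activations` matches the corresponding slice of the full latent vector. A quick sanity sketch, assuming the full encoding is `relu(encoder(x - decoder.bias))` as in the method body:

```python
import torch

sae = AutoEncoder(n_inputs=512, n_latents=2048)
x = torch.randn(4, 512)

full = sae.relu(sae.encoder(x - sae.decoder.bias))  # all 2048 latent features
subset = sae.get_feature_activations(x, 100, 200)   # only features 100..199

assert torch.allclose(full[..., 100:200], subset, atol=1e-5)
```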
@@ -90,26 +92,33 @@ def remove_parallel_component_of_decoder_grad(self):
9092 """
9193 Remove the component of the gradient parallel to the decoder's weight vectors.
9294 """
93- unit_weights = F .normalize (self .decoder .weight , dim = 0 ) # \hat{b}
94- proj = (self .decoder .weight .grad * unit_weights ).sum (dim = 0 ) * unit_weights
95+ unit_weights = F .normalize (self .decoder .weight , dim = 0 ) # \hat{b}
96+ proj = (self .decoder .weight .grad * unit_weights ).sum (dim = 0 ) * unit_weights
9597 self .decoder .weight .grad = self .decoder .weight .grad - proj
9698
97- @staticmethod
99+ @staticmethod
98100 def is_dead_neuron_investigation_step (step , resampling_interval , num_resamples ):
99101 """
100102 Determine if the current step is the start of a phase for investigating dead neurons.
101103 According to Anthropic's specified policy, it occurs at odd multiples of half the resampling interval.
102104 """
103- return (step > 0 ) and step % (resampling_interval // 2 ) == 0 and (step // (resampling_interval // 2 )) % 2 != 0 and step < resampling_interval * num_resamples
105+ return (
106+ (step > 0 )
107+ and step % (resampling_interval // 2 ) == 0
108+ and (step // (resampling_interval // 2 )) % 2 != 0
109+ and step < resampling_interval * num_resamples
110+ )
104111
105112 @staticmethod
106113 def is_within_neuron_investigation_phase (step , resampling_interval , num_resamples ):
107114 """
108115 Check if the current step is within a phase where active neurons are investigated.
109116 This phase occurs in intervals defined in the specified range, starting at odd multiples of half the resampling interval.
110117 """
111- return any (milestone - resampling_interval // 2 <= step < milestone
112- for milestone in range (resampling_interval , resampling_interval * (num_resamples + 1 ), resampling_interval ))
118+ return any (
119+ milestone - resampling_interval // 2 <= step < milestone
120+ for milestone in range (resampling_interval , resampling_interval * (num_resamples + 1 ), resampling_interval )
121+ )
113122
114123 @torch .no_grad ()
115124 def initiate_dead_neurons (self ):
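
Concretely, with the default `resampling_interval=25000` and a hypothetical `num_resamples=4`, investigation begins at steps 12500, 37500, 62500, and 87500, and each investigation phase covers the half-interval leading up to the next resampling milestone:

```python
starts = [s for s in range(100_000)
          if AutoEncoder.is_dead_neuron_investigation_step(s, 25000, 4)]
print(starts)  # [12500, 37500, 62500, 87500]

# Each phase covers [milestone - 12500, milestone) for milestones 25000, 50000, ...
assert AutoEncoder.is_within_neuron_investigation_phase(12500, 25000, 4)
assert not AutoEncoder.is_within_neuron_investigation_phase(25000, 25000, 4)
```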
@@ -138,7 +147,7 @@ def resample_dead_neurons(self, data, optimizer, batch_size=8192):
         average_enc_norm = self._compute_average_norm_of_alive_neurons(alive_neurons)
         probs = self._compute_loss_probabilities(data, batch_size, device)
         selected_examples = self._select_examples_based_on_probabilities(data, probs)
-
+
         self._resample_neurons(selected_examples, dead_neurons_t, average_enc_norm, device)
         self._update_optimizer_parameters(optimizer, dead_neurons_t)

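
The private helpers called here are not shown in this hunk, but the call sequence mirrors Anthropic's resampling recipe: weight candidate inputs by squared loss, sample replacement examples from that distribution, then reinitialize the dead neurons from those examples. A sketch of what the two selection steps might look like; the names and bodies below are assumptions, not the PR's actual code:

```python
import torch

def compute_loss_probabilities(losses: torch.Tensor) -> torch.Tensor:
    # Anthropic's recipe samples inputs with probability proportional to squared loss.
    weights = losses ** 2
    return weights / weights.sum()

def select_examples(data: torch.Tensor, probs: torch.Tensor, n_dead: int) -> torch.Tensor:
    # Draw one replacement input per dead neuron.
    idx = torch.multinomial(probs, n_dead, replacement=False)
    return data[idx]
```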