 
 sys.path.insert(1, '../')
 from resource_loader import ResourceLoader
-from utils.plotting_utils import make_histogram
+from utils.plotting_utils import make_activations_histogram, make_logits_histogram
 
 # hyperparameters
 # data and model
@@ -102,6 +102,15 @@ def __init__(self, config):
         self.html_out = os.path.join(os.path.dirname(os.path.abspath('.')), 'out', config.dataset, str(config.sae_ckpt_dir))
         self.seed = config.seed
 
+        # create subdirectories to store logit histograms and feature activation histograms
+        os.makedirs(os.path.join(self.html_out, 'logits_histograms'), exist_ok=True)
+        os.makedirs(os.path.join(self.html_out, 'activations_histograms'), exist_ok=True)
+
+        # TODO: why are my logits of the order of 10 while Anthropic's are < 1? Do they rescale them,
+        # or is it because of the linear approximation to LayerNorm?
+
+        # self.attributed_logits is a tensor of shape (n_features, vocab_size) containing the logits attributed to each feature
+        self.attributed_logits = self.compute_logits()
         self.top_logits, self.bottom_logits = self.compute_top_and_bottom_logits()
 
         print(f"Will process features in {self.num_phases} phases. Each phase will run its forward pass in {self.num_batches} batches")
@@ -124,6 +133,12 @@ def build(self):
             context_window_data = self.compute_context_window_data(feature_start_idx, feature_end_idx)
             top_acts_data = self.compute_top_activations(context_window_data)
             for h in range(0, feature_end_idx - feature_start_idx):
+                # make and save a histogram of logits for this feature
+                feature_id = phase * self.num_features_per_phase + h
+                make_logits_histogram(logits=self.attributed_logits[feature_id, :],
+                                      feature_id=feature_id,
+                                      dirpath=self.html_out)
+                # write the page for this feature
                 self.write_feature_page(phase, h, context_window_data, top_acts_data)
 
             # if phase == 1:
@@ -184,31 +199,39 @@ def compute_top_activations(self, data):
         return top_activations_data
 
     @torch.no_grad()
-    def compute_top_and_bottom_logits(self,):
+    def compute_logits(self,):
         """
-        Computes top and bottom logits for each feature.
-        Returns (top_logits, bottom_logits). Each is of type `torch.return_types.topk`.
-        It uses the full LayerNorm instead of its approximation. # TODO: How important is that?
-        # also, this function is specific to SAEs trained on the activations of last MLP layer for now.
+        Computes logits for each feature through the path-expansion approach.
+        Returns a torch tensor of shape (num_features, vocab_size).
+        By default, it uses the full LayerNorm instead of its linear approximation. # TODO: understand whether that's okay
+        # also, for now this function is specific to SAEs trained on the activations of the last MLP layer. TODO: generalize this
+        The logits for each feature are shifted so that their median value is 0.
         """
         mlp_out = self.transformer.transformer.h[-1].mlp.c_proj(self.autoencoder.decoder.weight.detach().t())  # (L, C)
         ln_out = self.transformer.transformer.ln_f(mlp_out)  # (L, C)
         logits = self.transformer.lm_head(ln_out)  # (L, V)
-        shifted_logits = (logits - logits.median(dim=1, keepdim=True).values)  # (L, V)
-
+        attributed_logits = (logits - logits.median(dim=1, keepdim=True).values)  # (L, V)
+        return attributed_logits
+
+    @torch.no_grad()
+    def compute_top_and_bottom_logits(self,):
+        """
+        Computes top and bottom logits for each feature.
+        Returns (top_logits, bottom_logits). Each is of type `torch.return_types.topk`.
+        """
         # GPT-2 tokenizer has vocab size 50257. nanoGPT sets vocab size = 50304 for higher training speed.
         # See https://twitter.com/karpathy/status/1621578354024677377?lang=en
         # The decoder will give an error if a token with id > 50256 is given, and bottom_logits may pick one of these tokens.
         # Therefore, set the max token id to 50256 by hand.
-        shifted_logits = shifted_logits[:, :50257]
-
-        top_logits = torch.topk(shifted_logits, largest=True, sorted=True, k=self.num_top_activations, dim=1)  # (L, k)
-        bottom_logits = torch.topk(shifted_logits, largest=False, sorted=True, k=self.num_top_activations, dim=1)  # (L, k)
+        attributed_logits = self.attributed_logits[:, :50257]
+        top_logits = torch.topk(attributed_logits, largest=True, sorted=True, k=self.num_top_activations, dim=1)  # (L, k)
+        bottom_logits = torch.topk(attributed_logits, largest=False, sorted=True, k=self.num_top_activations, dim=1)  # (L, k)
         return top_logits, bottom_logits
 
     def write_feature_page(self, phase, h, data, top_acts_data):
         """Writes feature pages for dead / alive neurons; also makes a histogram.
         For alive features, it calls sample_and_write."""
+
         curr_feature_acts_MW = data["feature_acts"][:, :, h]
         mid_token_feature_acts_M = curr_feature_acts_MW[:, self.window_radius]
         num_nonzero_acts = torch.count_nonzero(mid_token_feature_acts_M)
@@ -220,11 +243,12 @@ def write_feature_page(self, phase, h, data, top_acts_data):
 
         act_density = torch.count_nonzero(curr_feature_acts_MW) / (self.total_sampled_tokens * self.window_length) * 100
         non_zero_acts = curr_feature_acts_MW[curr_feature_acts_MW != 0]
-        make_histogram(activations=non_zero_acts,
+        make_activations_histogram(activations=non_zero_acts,
                        density=act_density,
                        feature_id=feature_id,
                        dirpath=self.html_out)
 
+
         if num_nonzero_acts < self.num_intervals * self.samples_per_interval:
             write_ultralow_density_feature_page(feature_id=feature_id,
                                                 decode=self.decode,
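
For readers who want to see the shapes involved in the new path-expansion attribution (compute_logits followed by compute_top_and_bottom_logits), the standalone sketch below reproduces the same chain of operations with randomly initialized stand-in modules. All of the names and sizes here (decoder, c_proj, ln_f, lm_head, n_features, n_mlp, n_embd) are hypothetical and are not taken from this repository's config.

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only: number of SAE features, MLP hidden
# width, embedding width, nanoGPT's padded vocab size, and k for topk.
n_features, n_mlp, n_embd, vocab_size, k = 16, 3072, 768, 50304, 10

# Randomly initialized stand-ins for the trained modules used in the diff:
# the SAE decoder, the last block's mlp.c_proj, the final LayerNorm, and lm_head.
decoder = nn.Linear(n_features, n_mlp, bias=False)
c_proj = nn.Linear(n_mlp, n_embd)
ln_f = nn.LayerNorm(n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)

with torch.no_grad():
    feature_dirs = decoder.weight.t()                            # (L, n_mlp): one decoder direction per feature
    logits = lm_head(ln_f(c_proj(feature_dirs)))                 # (L, V): push each direction through the rest of the model
    logits = logits - logits.median(dim=1, keepdim=True).values  # center each feature's logits at its median
    logits = logits[:, :50257]                                   # drop nanoGPT's padding token ids
    top_logits = torch.topk(logits, k=k, largest=True, sorted=True, dim=1)
    bottom_logits = torch.topk(logits, k=k, largest=False, sorted=True, dim=1)
    print(top_logits.values.shape, bottom_logits.indices.shape)  # both (L, k)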
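The make_logits_histogram helper itself lives in utils/plotting_utils.py and is not shown in this diff, so its body below is only a guess. Judging from the call site (logits, feature_id, dirpath) and the logits_histograms subdirectory created in __init__, a minimal matplotlib version might look roughly like this; the real helper may differ in binning, styling, and file naming.

import os
import matplotlib
matplotlib.use('Agg')  # render to files without a display
import matplotlib.pyplot as plt

def make_logits_histogram(logits, feature_id, dirpath):
    # Hypothetical sketch, not the repo's actual implementation:
    # save a histogram of one feature's attributed logits to
    # <dirpath>/logits_histograms/<feature_id>.png.
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.hist(logits.detach().cpu().numpy(), bins=60)
    ax.set_xlabel('attributed logit')
    ax.set_ylabel('token count')
    fig.tight_layout()
    fig.savefig(os.path.join(dirpath, 'logits_histograms', f'{feature_id}.png'))
    plt.close(fig)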