 R: window_radius
 W: window_length
 L: number of autoencoder latents
+H: number of features being processed in a phase
 N: total_sampled_tokens = (num_contexts * num_sampled_tokens)
 T: block_size (same as nanoGPT)
 B: gpt_batch_size (same as nanoGPT)
@@ -113,7 +114,7 @@ def build(self):
         In each phase, compute feature activations and feature ablations for (MLP activations of) text data `self.X`.
         Sample context windows from this data.
         Next, use `get_top_activations` to get tokens (along with windows) with top activations for each feature.
-        Note that sampled tokens are the same in all phases, thanks to the use of fn_seed in `select_context_windows`.
+        Note that sampled tokens are the same in all phases, thanks to the use of fn_seed in `_sample_context_windows`.
         """
         self.write_main_page()
 
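Editor's note, outside the diff: the loop that drives these phases is not shown here. Purely as a hypothetical sketch of how L latents might be split into phases of H features each, yielding the `(feature_start_idx, feature_end_idx)` pairs that `compute_context_window_data` expects; the helper name `iterate_phases` is an assumption, not code from this repository.

```python
# Hypothetical sketch (not from this repository): split L autoencoder latents
# into phases of at most H features each.
def iterate_phases(num_latents: int, features_per_phase: int):
    for start in range(0, num_latents, features_per_phase):
        yield start, min(start + features_per_phase, num_latents)

# Example: L = 10, H = 4  ->  (0, 4), (4, 8), (8, 10)
print(list(iterate_phases(10, 4)))
```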
@@ -136,10 +137,15 @@ def compute_context_window_data(self, feature_start_idx, feature_end_idx):
         context_window_data = self._initialize_context_window_data(feature_start_idx, feature_end_idx)
 
         for iter in range(self.num_batches):
-            print(f"Computing feature activations for batch # {iter + 1}/{self.num_batches}")
+            if iter % 20 == 0:
+                print(f"computing feature activations for batches {iter + 1}-{min(iter + 20, self.num_batches)}/{self.num_batches}")
             batch_start_idx = iter * self.gpt_batch_size
             batch_end_idx = (iter + 1) * self.gpt_batch_size
-            x, feature_activations = self._process_batch(batch_start_idx, batch_end_idx, feature_start_idx, feature_end_idx)
+            x, feature_activations = self._compute_batch_feature_activations(batch_start_idx,
+                                                                             batch_end_idx,
+                                                                             feature_start_idx,
+                                                                             feature_end_idx)
+            # x: (B, T), feature_activations: (B, T, H)
             x_context_windows, feature_acts_context_windows = self._sample_context_windows(x,
                                                                                            feature_activations,
                                                                                            fn_seed=self.seed + iter)
@@ -152,15 +158,31 @@ def compute_context_window_data(self, feature_start_idx, feature_end_idx):
         return context_window_data
 
     def compute_top_activations(self, data):
-        """"Computes top activations given context window data"""
-        num_features_in_phase = data["feature_acts"].shape[-1]
-        _, topk_indices = torch.topk(data["feature_acts"][:, self.window_radius, :], k=self.num_top_activations, dim=0)  # (k, H)
-        top_acts_data = TensorDict({
-            "tokens": data["tokens"][topk_indices].transpose(dim0=1, dim1=2),
-            "feature_acts": torch.stack([data["feature_acts"][topk_indices[:, i], :, i] for i in range(num_features_in_phase)], dim=-1)
-            }, batch_size=[self.num_top_activations, self.window_length, num_features_in_phase])  # (k, W, H)
+        """Computes top activations from the given context window data.
+        `data` is a TensorDict with keys `tokens` and `feature_acts` of shapes (B*S, W) and (B*S, W, H) respectively."""
 
-        return top_acts_data
+        num_features = data["feature_acts"].shape[-1]  # Label this as H.
+
+        # Find the indices of the top activations at the center of the window
+        _, top_indices = torch.topk(data["feature_acts"][:, self.window_radius, :],
+                                    k=self.num_top_activations, dim=0)  # (k, H)
+
+        # Prepare the tokens corresponding to the top activations
+        top_tokens = data["tokens"][top_indices].transpose(dim0=1, dim1=2)  # (k, W, H)
+
+        # Extract and stack the top feature activations for each feature across all windows
+        top_feature_activations = torch.stack(
+            [data["feature_acts"][top_indices[:, i], :, i] for i in range(num_features)],
+            dim=-1
+        )  # (k, W, H)
+
+        # Bundle the top tokens and feature activations into a structured data format
+        top_activations_data = TensorDict({
+            "tokens": top_tokens,
+            "feature_acts": top_feature_activations
+        }, batch_size=[self.num_top_activations, self.window_length, num_features])  # (k, W, H)
+
+        return top_activations_data
 
     def write_feature_page(self, phase, h, data, top_acts_data):
         """Writes feature pages for dead / alive neurons; also makes a histogram.
@@ -216,26 +238,12 @@ def _sample_context_windows(self, *args, fn_seed=0):
         Given tensors each of shape (B, T, ...), this function returns tensors containing
         windows around randomly selected tokens. The shape of the output is (B * S, W, ...),
         where S is the number of tokens in each context to evaluate, and W is the window size
-        (including the token itself and tokens on either side).
+        (including the token itself and tokens on either side). Here, S = self.num_sampled_tokens
+        and W = self.window_length.
 
-        Parameters: #TODO: update parameters here
+        Parameters:
         - args: Variable number of tensor arguments, each of shape (B, T, ...)
-        - num_sampled_tokens (int): The number of tokens in each context on which to evaluate
-        - window_radius (int): The number of tokens on either side of the sampled token
         - fn_seed (int, optional): Seed for random number generator, default is 0
-
-        Returns:
-        - A list of tensors, each of shape (B * S, W, ...), where S is `num_sampled_tokens` and W is
-          the window size calculated as 2 * `window_radius` + 1.
-
-        Raises:
-        - AssertionError: If no tensors are provided, or if the tensors do not have the required shape.
-
-        Example usage:
-        ```
-        tensor1 = torch.randn(10, 20, 30)  # Example tensor
-        windows = select_context_windows(tensor1, num_sampled_tokens=5, window_radius=2)
-        ```
         """
         if not args or not all(isinstance(tensor, torch.Tensor) and tensor.ndim >= 2 for tensor in args):
             raise ValueError("All inputs must be torch tensors with at least 2 dimensions.")
@@ -247,22 +255,20 @@ def _sample_context_windows(self, *args, fn_seed=0):
 
         torch.manual_seed(fn_seed)
         num_sampled_tokens = self.num_sampled_tokens
-        window_radius = self.window_radius
-        window_length = 2 * window_radius + 1
-        token_idx = torch.stack([window_radius + torch.randperm(T - 2 * window_radius)[:num_sampled_tokens]
-                                 for _ in range(B)], dim=0)  # (B, S)
-        window_idx = token_idx.unsqueeze(-1) + torch.arange(-window_radius, window_radius + 1)  # (B, S, W)
+        token_idx = torch.stack([self.window_radius + torch.randperm(T - 2 * self.window_radius)[:num_sampled_tokens]
+                                 for _ in range(B)], dim=0)  # (B, S); torch.randperm samples without replacement
+        window_idx = token_idx.unsqueeze(-1) + torch.arange(-self.window_radius, self.window_radius + 1)  # (B, S, W)
         batch_idx = torch.arange(B).view(-1, 1, 1).expand_as(window_idx)  # (B, S, W)
 
         result_tensors = []
         for tensor in args:
             if tensor.ndim == 3:
                 L = tensor.shape[2]
                 sliced_tensor = tensor[batch_idx, window_idx, :]  # (B, S, W, L)
-                sliced_tensor = sliced_tensor.view(-1, window_length, L)  # (B*S, W, L)
+                sliced_tensor = sliced_tensor.view(-1, self.window_length, L)  # (B*S, W, L)
             elif tensor.ndim == 2:
                 sliced_tensor = tensor[batch_idx, window_idx]  # (B, S, W)
-                sliced_tensor = sliced_tensor.view(-1, window_length)  # (B*S, W)
+                sliced_tensor = sliced_tensor.view(-1, self.window_length)  # (B*S, W)
             else:
                 raise ValueError("Tensor dimensions not supported. Only 2D and 3D tensors are allowed.")
             result_tensors.append(sliced_tensor)
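Editor's note, outside the diff: the indexing above can be verified on a toy tensor. A minimal sketch, assuming only the file's notation (B contexts of T tokens, S sampled tokens per context, window radius R, W = 2*R + 1):

```python
# Minimal sketch of the window-sampling indexing, on a toy 2D tensor.
import torch

B, T, S, R = 2, 10, 3, 1
W = 2 * R + 1
x = torch.arange(B * T).view(B, T)        # stand-in for a (B, T) token tensor

torch.manual_seed(0)
# For each context, sample S distinct positions that leave room for a full window.
token_idx = torch.stack([R + torch.randperm(T - 2 * R)[:S] for _ in range(B)], dim=0)   # (B, S)
window_idx = token_idx.unsqueeze(-1) + torch.arange(-R, R + 1)                          # (B, S, W)
batch_idx = torch.arange(B).view(-1, 1, 1).expand_as(window_idx)                        # (B, S, W)

windows = x[batch_idx, window_idx].view(-1, W)   # (B*S, W)
print(windows.shape)                             # torch.Size([6, 3])
```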
@@ -277,7 +283,7 @@ def _initialize_context_window_data(self, feature_start_idx, feature_end_idx):
         }, batch_size=[self.total_sampled_tokens, self.window_length])  # (N, W)
         return context_window_data
 
-    def _process_batch(self, batch_start_idx, batch_end_idx, feature_start_idx, feature_end_idx):
+    def _compute_batch_feature_activations(self, batch_start_idx, batch_end_idx, feature_start_idx, feature_end_idx):
         """Computes feature activations for the given batch of input text.
         """
         x = self.X[batch_start_idx:batch_end_idx].to(self.device)
@@ -301,4 +307,5 @@ def write_main_page(self):
     feature_browser.build()
 
 
-    #TODO: tooltip css function should be imported separately and written explicitly I think, for clarity
+    #TODO: tooltip css function should be imported separately and written explicitly I think, for clarity
+    # TODO: methods that need to be revisited: write_feature_page, sample_and_write.