Commit acf8d6b

more code cleanup in FeatureBrowser class; committing at an intermediate state
1 parent 21ab744 commit acf8d6b

File tree: 1 file changed (+44, -37 lines)

autoencoder/feature-browser/build_website.py

Lines changed: 44 additions & 37 deletions
@@ -7,6 +7,7 @@
 R: window_radius
 W: window_length
 L: number of autoencoder latents
+H: Number of features being processed in a phase
 N: total_sampled_tokens = (num_contexts * num_sampled_tokens)
 T: block_size (same as nanoGPT)
 B: gpt_batch_size (same as nanoGPT)
@@ -113,7 +114,7 @@ def build(self):
         In each phase, compute feature activations and feature ablations for (MLP activations of) text data `self.X`.
         Sample context windows from this data.
         Next, use `get_top_activations` to get tokens (along with windows) with top activations for each feature.
-        Note that sampled tokens are the same in all phases, thanks to the use of fn_seed in `select_context_windows`.
+        Note that sampled tokens are the same in all phases, thanks to the use of fn_seed in `_sample_context_windows`.
         """
         self.write_main_page()
 
@@ -136,10 +137,15 @@ def compute_context_window_data(self, feature_start_idx, feature_end_idx):
         context_window_data = self._initialize_context_window_data(feature_start_idx, feature_end_idx)
 
         for iter in range(self.num_batches):
-            print(f"Computing feature activations for batch # {iter+1}/{self.num_batches}")
+            if iter % 20 == 0:
+                print(f"computing feature activations for batches {iter+1}-{min(iter+20, self.num_batches)}/{self.num_batches}")
             batch_start_idx = iter * self.gpt_batch_size
             batch_end_idx = (iter + 1) * self.gpt_batch_size
-            x, feature_activations = self._process_batch(batch_start_idx, batch_end_idx, feature_start_idx, feature_end_idx)
+            x, feature_activations = self._compute_batch_feature_activations(batch_start_idx,
+                                                                             batch_end_idx,
+                                                                             feature_start_idx,
+                                                                             feature_end_idx)
+            # x: (B, T), feature_activations: (B, T, H)
             x_context_windows, feature_acts_context_windows = self._sample_context_windows(x,
                                                                                            feature_activations,
                                                                                            fn_seed=self.seed+iter)
@@ -152,15 +158,31 @@ def compute_context_window_data(self, feature_start_idx, feature_end_idx):
         return context_window_data
 
     def compute_top_activations(self, data):
-        """"Computes top activations given context window data"""
-        num_features_in_phase = data["feature_acts"].shape[-1]
-        _, topk_indices = torch.topk(data["feature_acts"][:, self.window_radius, :], k=self.num_top_activations, dim=0) # (k, H)
-        top_acts_data = TensorDict({
-            "tokens": data["tokens"][topk_indices].transpose(dim0=1, dim1=2),
-            "feature_acts": torch.stack([data["feature_acts"][topk_indices[:, i], :, i] for i in range(num_features_in_phase)], dim=-1)
-        }, batch_size=[self.num_top_activations, self.window_length, num_features_in_phase]) # (k, W, H)
+        """Computes top activations of given context window data.
+        `data` is a TensorDict with keys `tokens` and `feature_acts` of shapes (B*S, W) and (B*S, W, H) respectively."""
 
-        return top_acts_data
+        num_features = data["feature_acts"].shape[-1] # Label this as H.
+
+        # Find the indices of the top activations at the center of the window
+        _, top_indices = torch.topk(data["feature_acts"][:, self.window_radius, :],
+                                    k=self.num_top_activations, dim=0) # (k, H)
+
+        # Prepare the tokens corresponding to the top activations
+        top_tokens = data["tokens"][top_indices].transpose(dim0=1, dim1=2) # (k, W, H)
+
+        # Extract and stack the top feature activations for each feature across all windows
+        top_feature_activations = torch.stack(
+            [data["feature_acts"][top_indices[:, i], :, i] for i in range(num_features)],
+            dim=-1
+        ) # (k, W, H)
+
+        # Bundle the top tokens and feature activations into a structured data format
+        top_activations_data = TensorDict({
+            "tokens": top_tokens,
+            "feature_acts": top_feature_activations
+        }, batch_size=[self.num_top_activations, self.window_length, num_features]) # (k, W, H)
+
+        return top_activations_data
 
     def write_feature_page(self, phase, h, data, top_acts_data):
         """"Writes features pages for dead / alive neurons; also makes a histogram.
@@ -216,26 +238,12 @@ def _sample_context_windows(self, *args, fn_seed=0):
         Given tensors each of shape (B, T, ...), this function returns tensors containing
         windows around randomly selected tokens. The shape of the output is (B * S, W, ...),
         where S is the number of tokens in each context to evaluate, and W is the window size
-        (including the token itself and tokens on either side).
+        (including the token itself and tokens on either side). By default, S = self.num_sampled_tokens,
+        W = self.window_length.
 
-        Parameters: #TODO: update parameters here
+        Parameters:
         - args: Variable number of tensor arguments, each of shape (B, T, ...)
-        - num_sampled_tokens (int): The number of tokens in each context on which to evaluate
-        - window_radius (int): The number of tokens on either side of the sampled token
         - fn_seed (int, optional): Seed for random number generator, default is 0
-
-        Returns:
-        - A list of tensors, each of shape (B * S, W, ...), where S is `num_sampled_tokens` and W is
-          the window size calculated as 2 * `window_radius` + 1.
-
-        Raises:
-        - AssertionError: If no tensors are provided, or if the tensors do not have the required shape.
-
-        Example usage:
-        ```
-        tensor1 = torch.randn(10, 20, 30) # Example tensor
-        windows = select_context_windows(tensor1, num_sampled_tokens=5, window_radius=2)
-        ```
         """
         if not args or not all(isinstance(tensor, torch.Tensor) and tensor.ndim >= 2 for tensor in args):
             raise ValueError("All inputs must be torch tensors with at least 2 dimensions.")
@@ -247,22 +255,20 @@ def _sample_context_windows(self, *args, fn_seed=0):
 
         torch.manual_seed(fn_seed)
         num_sampled_tokens = self.num_sampled_tokens
-        window_radius = self.window_radius
-        window_length = 2 * window_radius + 1
-        token_idx = torch.stack([window_radius + torch.randperm(T - 2 * window_radius)[:num_sampled_tokens]
-                                 for _ in range(B)], dim=0) # (B, S)
-        window_idx = token_idx.unsqueeze(-1) + torch.arange(-window_radius, window_radius + 1) # (B, S, W)
+        token_idx = torch.stack([self.window_radius + torch.randperm(T - 2 * self.window_radius)[:num_sampled_tokens]
+                                 for _ in range(B)], dim=0) # (B, S); torch.randperm samples without replacement
+        window_idx = token_idx.unsqueeze(-1) + torch.arange(-self.window_radius, self.window_radius + 1) # (B, S, W)
         batch_idx = torch.arange(B).view(-1, 1, 1).expand_as(window_idx) # (B, S, W)
 
         result_tensors = []
         for tensor in args:
             if tensor.ndim == 3:
                 L = tensor.shape[2]
                 sliced_tensor = tensor[batch_idx, window_idx, :] # (B, S, W, L)
-                sliced_tensor = sliced_tensor.view(-1, window_length, L) # (B*S, W, L)
+                sliced_tensor = sliced_tensor.view(-1, self.window_length, L) # (B*S, W, L)
             elif tensor.ndim == 2:
                 sliced_tensor = tensor[batch_idx, window_idx] # (B, S, W)
-                sliced_tensor = sliced_tensor.view(-1, window_length) # (B*S, W)
+                sliced_tensor = sliced_tensor.view(-1, self.window_length) # (B*S, W)
             else:
                 raise ValueError("Tensor dimensions not supported. Only 2D and 3D tensors are allowed.")
             result_tensors.append(sliced_tensor)
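
Aside (not part of the commit): the indexing trick in `_sample_context_windows` can be tried standalone. The sketch below uses hypothetical toy shapes: sample S center positions per context without replacement, expand each to a window of length 2 * window_radius + 1, and gather both 2D and 3D tensors with advanced indexing.

```python
import torch

# Hypothetical toy shapes: B = 2 contexts of T = 12 tokens, S = 3 sampled tokens, window radius 2.
B, T, S, window_radius = 2, 12, 3, 2
W = 2 * window_radius + 1

tokens = torch.arange(B * T).view(B, T)  # (B, T); distinct values make the windows easy to inspect
acts = torch.rand(B, T, 4)               # (B, T, L) with L = 4 latents

torch.manual_seed(0)
# Sample S center positions per context, without replacement, away from the edges.
token_idx = torch.stack([window_radius + torch.randperm(T - 2 * window_radius)[:S]
                         for _ in range(B)], dim=0)                                     # (B, S)
window_idx = token_idx.unsqueeze(-1) + torch.arange(-window_radius, window_radius + 1)  # (B, S, W)
batch_idx = torch.arange(B).view(-1, 1, 1).expand_as(window_idx)                        # (B, S, W)

# Advanced indexing broadcasts the (B, S, W) index tensors over the batch and time dims.
token_windows = tokens[batch_idx, window_idx].view(-1, W)    # (B*S, W)
act_windows = acts[batch_idx, window_idx, :].view(-1, W, 4)  # (B*S, W, L)
print(token_windows.shape, act_windows.shape)  # torch.Size([6, 5]) torch.Size([6, 5, 4])
```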
@@ -277,7 +283,7 @@ def _initialize_context_window_data(self, feature_start_idx, feature_end_idx):
         }, batch_size=[self.total_sampled_tokens, self.window_length]) # (N * S, W)
         return context_window_data
 
-    def _process_batch(self, batch_start_idx, batch_end_idx, feature_start_idx, feature_end_idx):
+    def _compute_batch_feature_activations(self, batch_start_idx, batch_end_idx, feature_start_idx, feature_end_idx):
         """Computes feature activations for given batch of input text.
         """
         x = self.X[batch_start_idx:batch_end_idx].to(self.device)
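
Aside (not part of the commit): the `TensorDict` construction visible in the context lines requires `batch_size` to be a common leading shape of every entry, and indexing the TensorDict then slices all entries at once. A minimal sketch with made-up shapes:

```python
import torch
from tensordict import TensorDict

# Hypothetical toy shapes: k = 2 top windows, W = 5 tokens per window, H = 3 features.
k, W, H = 2, 5, 3

td = TensorDict({
    "tokens": torch.randint(0, 100, (k, W, H)),
    "feature_acts": torch.rand(k, W, H),
}, batch_size=[k, W, H])  # batch_size must match the leading dimensions of every entry

# Indexing the TensorDict slices every entry consistently.
first_feature = td[:, :, 0]           # TensorDict with batch_size [k, W]
print(first_feature["tokens"].shape)  # torch.Size([2, 5])
```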
@@ -301,4 +307,5 @@ def write_main_page(self):
     feature_browser.build()
 
 
-#TODO: tooltip css function should be imported separately and written explicitly I think, for clarity
+#TODO: tooltip css function should be imported separately and written explicitly I think, for clarity
+# TODO: methods that need to be revisited: write_feature_page, sample_and_write.
