
Commit 53f521e

added logit histogram in feature browser

1 parent 3893fa7

File tree: 6 files changed (+75, -28 lines)

autoencoder/feature-browser/build_website.py

Lines changed: 37 additions & 13 deletions
@@ -30,7 +30,7 @@

 sys.path.insert(1, '../')
 from resource_loader import ResourceLoader
-from utils.plotting_utils import make_histogram
+from utils.plotting_utils import make_activations_histogram, make_logits_histogram

 # hyperparameters
 # data and model
@@ -102,6 +102,15 @@ def __init__(self, config):
     self.html_out = os.path.join(os.path.dirname(os.path.abspath('.')), 'out', config.dataset, str(config.sae_ckpt_dir))
     self.seed = config.seed

+    # create subdirectories to store logit histograms and feature activation histograms
+    os.makedirs(os.path.join(self.html_out, 'logits_histograms'), exist_ok=True)
+    os.makedirs(os.path.join(self.html_out, 'activations_histograms'), exist_ok=True)
+
+    # TODO: why are my logits of the order of 10, while Anthropic's are <1? Do they rescale them,
+    # or is it because of the linear approximation to LayerNorm?
+
+    # self.attributed_logits is a tensor of shape (n_features, vocab_size) containing logits for each feature
+    self.attributed_logits = self.compute_logits()
     self.top_logits, self.bottom_logits = self.compute_top_and_bottom_logits()

     print(f"Will process features in {self.num_phases} phases. Each phase will have forward pass in {self.num_batches} batches")
@@ -124,6 +133,12 @@ def build(self):
     context_window_data = self.compute_context_window_data(feature_start_idx, feature_end_idx)
     top_acts_data = self.compute_top_activations(context_window_data)
     for h in range(0, feature_end_idx-feature_start_idx):
+        # make and save histogram of logits for this feature
+        feature_id = phase * self.num_features_per_phase + h
+        make_logits_histogram(logits=self.attributed_logits[feature_id, :],
+                              feature_id=feature_id,
+                              dirpath=self.html_out)
+        # write the page for this feature
         self.write_feature_page(phase, h, context_window_data, top_acts_data)

     # if phase == 1:
@@ -184,31 +199,39 @@ def compute_top_activations(self, data):
     return top_activations_data

 @torch.no_grad()
-def compute_top_and_bottom_logits(self,):
+def compute_logits(self,):
     """
-    Computes top and bottom logits for each feature.
-    Returns (top_logits, bottom_logits). Each is of type `torch.return_types.topk`.
-    It uses the full LayerNorm instead of its approximation. # TODO: How important is that?
-    # also, this function is specific to SAEs trained on the activations of last MLP layer for now.
+    Computes logits for each feature through the path expansion approach.
+    Returns a torch tensor of shape (num_features, vocab_size).
+    By default, it uses the full LayerNorm instead of its linear approximation. # TODO: understand if that's okay
+    # also, this function is specific to SAEs trained on the activations of the last MLP layer for now. TODO: generalize this
+    By default, logits for each feature are shifted so that the median value is 0.
     """
     mlp_out = self.transformer.transformer.h[-1].mlp.c_proj(self.autoencoder.decoder.weight.detach().t()) # (L, C)
     ln_out = self.transformer.transformer.ln_f(mlp_out) # (L, C)
     logits = self.transformer.lm_head(ln_out) # (L, V)
-    shifted_logits = (logits - logits.median(dim=1, keepdim=True).values) # (L, V)
-
+    attributed_logits = (logits - logits.median(dim=1, keepdim=True).values) # (L, V)
+    return attributed_logits
+
+@torch.no_grad()
+def compute_top_and_bottom_logits(self,):
+    """
+    Computes top and bottom logits for each feature.
+    Returns (top_logits, bottom_logits). Each is of type `torch.return_types.topk`.
+    """
     # GPT-2 tokenizer has vocab size 50257. nanoGPT sets vocab size = 50304 for higher training speed.
     # See https://twitter.com/karpathy/status/1621578354024677377?lang=en
     # Decoder will give an error if a token with id > 50256 is given, and bottom_logits may pick one of these tokens.
     # Therefore, set max token id to 50256 by hand.
-    shifted_logits = shifted_logits[:, :50257]
-
-    top_logits = torch.topk(shifted_logits, largest=True, sorted=True, k=self.num_top_activations, dim=1) # (L, k)
-    bottom_logits = torch.topk(shifted_logits, largest=False, sorted=True, k=self.num_top_activations, dim=1) # (L, k)
+    attributed_logits = self.attributed_logits[:, :50257]
+    top_logits = torch.topk(attributed_logits, largest=True, sorted=True, k=self.num_top_activations, dim=1) # (L, k)
+    bottom_logits = torch.topk(attributed_logits, largest=False, sorted=True, k=self.num_top_activations, dim=1) # (L, k)
     return top_logits, bottom_logits

 def write_feature_page(self, phase, h, data, top_acts_data):
     """Writes feature pages for dead / alive neurons; also makes a histogram.
     For alive features, it calls sample_and_write."""
+
     curr_feature_acts_MW = data["feature_acts"][:, :, h]
     mid_token_feature_acts_M = curr_feature_acts_MW[:, self.window_radius]
     num_nonzero_acts = torch.count_nonzero(mid_token_feature_acts_M)
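Note: pulled out of the class, the path expansion amounts to the short, self-contained sketch below. It assumes a nanoGPT-style GPT-2 wrapper and an SAE whose decoder maps feature space back to MLP activations; transformer and autoencoder are stand-ins for the loaded models, and attribute_logits is an illustrative name, not a function in this repo.

    import torch

    @torch.no_grad()
    def attribute_logits(transformer, autoencoder):
        # one row per feature: the direction each feature writes into MLP activation space
        W_dec = autoencoder.decoder.weight.detach().t()            # (L, mlp_dim)
        # push every feature direction through the tail of the forward pass
        mlp_out = transformer.transformer.h[-1].mlp.c_proj(W_dec)  # (L, C) residual-stream write
        ln_out = transformer.transformer.ln_f(mlp_out)             # (L, C) full final LayerNorm
        logits = transformer.lm_head(ln_out)                       # (L, V) unembed to vocabulary
        # median-center each row so a feature's "typical" token sits at logit 0
        return logits - logits.median(dim=1, keepdim=True).values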
@@ -220,11 +243,12 @@ def write_feature_page(self, phase, h, data, top_acts_data):

     act_density = torch.count_nonzero(curr_feature_acts_MW) / (self.total_sampled_tokens * self.window_length) * 100
     non_zero_acts = curr_feature_acts_MW[curr_feature_acts_MW != 0]
-    make_histogram(activations=non_zero_acts,
+    make_activations_histogram(activations=non_zero_acts,
                    density=act_density,
                    feature_id=feature_id,
                    dirpath=self.html_out)

+
     if num_nonzero_acts < self.num_intervals * self.samples_per_interval:
         write_ultralow_density_feature_page(feature_id=feature_id,
                                             decode=self.decode,

autoencoder/feature-browser/subpages.py

Lines changed: 7 additions & 8 deletions
@@ -39,7 +39,7 @@ def write_feature_page_header():
     flex: 1; /* Make each column take up equal space */
     padding: 0 10px; /* Add some padding */
     border: 2px solid #ccc;
-    background-color: #f9f9f9;
+    background-color: #ffffff;
 }}
 h2 {{
     margin-bottom: 15px;
@@ -134,21 +134,20 @@ def write_activations_section(decode, examples_data):


 def include_feature_density_histogram(feature_id, dirpath=None):
-    if os.path.exists(os.path.join(dirpath, 'histograms', f'{feature_id}.png')):
+    if os.path.exists(os.path.join(dirpath, 'activations_histograms', f'{feature_id}.png')):
         feature_density_histogram = f"""
         <div class="image-container">
-            <img src="../histograms/{feature_id}.png" alt="Feature Activations Histogram">
+            <img src="../activations_histograms/{feature_id}.png" alt="Feature Activations Histogram">
         </div>"""
         return feature_density_histogram
     else:
         return ""

 def include_logits_histogram(feature_id, dirpath=None):
-    # TODO: replace histograms with logits in the path below.
-    if os.path.exists(os.path.join(dirpath, 'histograms', f'{feature_id}.png')):
+    if os.path.exists(os.path.join(dirpath, 'logits_histograms', f'{feature_id}.png')):
         logits_histogram = f"""
         <div class="second-image-container">
-            <img src="../histograms/{feature_id}.png" alt="Logits Histogram" width="100" height="50">
+            <img src="../logits_histograms/{feature_id}.png" alt="Logits Histogram" width="400" height="200">
         </div>
         </div>
         """
@@ -255,10 +254,10 @@ def write_ultralow_density_feature_page(feature_id, decode, top_acts_data, dirpa
     html_content.append(write_feature_page_header())

     # add histogram of feature activations
-    if os.path.exists(os.path.join(dirpath, 'histograms', f'{feature_id}.png')):
+    if os.path.exists(os.path.join(dirpath, 'activations_histograms', f'{feature_id}.png')):
         html_content.append(f"""<div class="content-container">
         <div class="image-container">
-            <img src="../histograms/{feature_id}.png" alt="Feature Activations Histogram">
+            <img src="../activations_histograms/{feature_id}.png" alt="Feature Activations Histogram">
         </div>""")

     # add feature #, and the information that it is an ultralow density neuron
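Note: a hypothetical call site for the new helper (html_out and html_content are illustrative names, not identifiers shown in this diff); the helper returns an empty string until the corresponding PNG exists, so the page degrades gracefully:

    snippet = include_logits_histogram(feature_id=42, dirpath=html_out)
    if snippet:
        html_content.append(snippet)  # embeds ../logits_histograms/42.png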

autoencoder/resource_loader.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def __init__(self, dataset, gpt_ckpt_dir, device='cpu', mode="train", sae_ckpt_d
     assert sae_ckpt_dir, "A path to autoencoder checkpoint must be given"
     self.sae_ckpt_dir = sae_ckpt_dir
     self.autoencoder = self.load_autoencoder_model()
+    self.autoencoder.eval() # note that if we load an autoencoder to resume training, we must not do this

 def load_text_data(self):
     """Loads the text data from the specified dataset."""

autoencoder/train.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 import time
 from autoencoder import AutoEncoder
 from resource_loader import ResourceLoader
-from utils.plotting_utils import make_histogram_image
+from utils.plotting_utils import make_density_histogram

 ## hyperparameters
 # dataset and model
@@ -155,7 +155,7 @@

 # compute feature densities and plot feature density histogram
 log_feat_acts_density = np.log10(feat_acts_count[feat_acts_count != 0]/(eval_iters * eval_batch_size * gpt.config.block_size)) # (n_features,)
-feat_density_historgram = make_histogram_image(log_feat_acts_density)
+feat_density_historgram = make_density_histogram(log_feat_acts_density)

 # take mean of all loss values by dividing by the number of evaluation batches; also log more metrics
 log_dict = {key: val/eval_iters for key, val in log_dict.items()}
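Note: to make the density formula concrete, a small worked example with made-up evaluation sizes (the real values come from the training config):

    import numpy as np

    eval_iters, eval_batch_size, block_size = 10, 8, 1024     # hypothetical sizes
    total_tokens = eval_iters * eval_batch_size * block_size  # 81,920 tokens evaluated
    feat_acts_count = np.array([500, 0, 81920])               # firing counts for 3 features
    density = feat_acts_count[feat_acts_count != 0] / total_tokens
    log_density = np.log10(density)                           # approx. [-2.21, 0.0]
    # dead features (count == 0) are filtered out first, avoiding log10(0) = -inf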

autoencoder/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .plotting_utils import make_histogram_image
+from .plotting_utils import make_density_histogram, make_activations_histogram, make_logits_histogram

autoencoder/utils/plotting_utils.py

Lines changed: 27 additions & 4 deletions
@@ -1,11 +1,17 @@
+"""
+Three histogram functions. They differ in whether the image is saved to disk or returned,
+in color scheme, and in axis labels.
+These could perhaps be combined into one function, but leaving them separate for now.
+"""
 import matplotlib.pyplot as plt
 from PIL import Image
 from io import BytesIO
 import torch
 import os

-def make_histogram_image(data, bins='auto'):
-    """Generates a histogram image from the provided data."""
+def make_density_histogram(data, bins='auto'):
+    """Makes a histogram image from the provided data and returns it.
+    We use it in train.py to plot feature density histograms and log them with W&B."""
     fig, ax = plt.subplots()
     ax.hist(data, bins=bins)
     ax.set_title('Histogram')
@@ -19,7 +25,9 @@ def make_histogram_image(data, bins='auto'):
     plt.close(fig) # close the figure to free memory
     return image

-def make_histogram(activations, density, feature_id, dirpath=None):
+def make_activations_histogram(activations, density, feature_id, dirpath=None):
+    """Makes a histogram of activations and saves it on disk;
+    we later include the histogram in the feature browser."""
     if isinstance(activations, torch.Tensor):
         activations = activations.cpu().numpy()
     plt.hist(activations, bins='auto') # You can adjust the number of bins as needed
@@ -28,5 +36,20 @@ def make_histogram(activations, density, feature_id, dirpath=None):
     plt.ylabel('Frequency')

     # Save the histogram as an image
-    plt.savefig(os.path.join(dirpath, 'histograms', f'{feature_id}.png'))
+    image_path = os.path.join(dirpath, 'activations_histograms', f'{feature_id}.png')
+    plt.savefig(image_path)
+    plt.close()
+
+def make_logits_histogram(logits, feature_id, dirpath=None):
+    """
+    Makes a histogram of logits for a given feature and saves it as a PNG file.
+    Input:
+        logits: a torch tensor of shape (vocab_size,)
+        feature_id: int
+        dirpath: histogram is saved as dirpath/logits_histograms/feature_id.png
+    """
+    plt.hist(logits, bins='auto') # You can adjust the number of bins as needed
+
+    image_path = os.path.join(dirpath, 'logits_histograms', f'{feature_id}.png')
+    plt.savefig(image_path)
     plt.close()
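Note: unlike make_activations_histogram, make_logits_histogram hands logits straight to plt.hist without a torch-to-numpy conversion, so a CUDA tensor should be moved to the CPU at the call site. A hypothetical call (variable names illustrative):

    row = attributed_logits[feature_id, :].cpu()  # matplotlib cannot consume CUDA tensors
    make_logits_histogram(logits=row, feature_id=feature_id, dirpath=html_out)
    # writes html_out/logits_histograms/<feature_id>.png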
