Commit cb8609b

Lots of small cleanups
1 parent cb4ae42 commit cb8609b

6 files changed: 34 additions & 183 deletions

autoencoder/feature-browser/build_website.py

Lines changed: 30 additions & 27 deletions
@@ -20,7 +20,7 @@
 """

 import logging
-from tqdm.auto import trange, tqdm
+from tqdm.auto import trange
 from dataclasses import dataclass
 import torch
 from tensordict import TensorDict
@@ -37,8 +37,8 @@
 # hyperparameters
 # data and model
 dataset = 'openwebtext'
-gpt_ckpt_dir = 'out'
-sae_ckpt_dir = 0.0 # subdirectory containing the specific model to consider
+gpt_ckpt_dir = 'MUST_BE_PROVIDED'
+sae_ckpt_dir = 'MUST_BE_PROVIDED' # subdirectory containing the specific model to consider
 # feature page hyperparameter
 num_contexts = 10000
 num_sampled_tokens = 10 # number of tokens in each context on which feature activations will be computed
@@ -50,29 +50,29 @@
 gpt_batch_size = 156
 num_phases = 52 # due to memory constraints, it's useful to process features in phases.
 # system
-device = 'cuda' # change it to cpu
+device = 'mps'
 # reproducibility
 seed = 1442


 @dataclass
 class FeatureBrowserConfig:
     # dataset and model
-    dataset: str = "openwebtext"
-    gpt_ckpt_dir: str = "out"
-    sae_ckpt_dir: str = "out"
+    dataset: str
+    gpt_ckpt_dir: str
+    sae_ckpt_dir: str
     # feature browser hyperparameters
-    num_contexts: int = int(1e6)
-    num_sampled_tokens: int = 10
-    window_radius: int = 4
-    num_top_activations: int = 10
-    num_intervals: int = 12
-    samples_per_interval: int = 5
+    num_contexts: int
+    num_sampled_tokens: int
+    window_radius: int
+    num_top_activations: int
+    num_intervals: int
+    samples_per_interval: int
     # processing hyperparameters
-    seed: int = 0
-    device: str = "cpu"
-    gpt_batch_size: int = 156
-    num_phases: int = 52
+    seed: int
+    device: str
+    gpt_batch_size: int
+    num_phases: int


 class FeatureBrowser(ResourceLoader):
@@ -86,7 +86,6 @@ def __init__(self, config):
         )

         # retrieve feature browser hyperparameters from config
-        # self.num_contexts = config.num_contexts
         self.num_sampled_tokens = config.num_sampled_tokens
         self.window_radius = config.window_radius
         self.num_top_activations = num_top_activations
@@ -136,11 +135,15 @@ def build(self):

         for phase in trange(self.num_phases, desc='processing features in phases'):
             feature_start_idx = phase * self.num_features_per_phase
-            feature_end_idx = min((phase + 1) * self.num_features_per_phase, self.n_features) # 4096 features in the latent space of the autoencoder
-            # logging.info(f'working on features # {feature_start_idx} - {feature_end_idx} in phase {phase + 1}/{self.num_phases}')
+            feature_end_idx = min((phase + 1) * self.num_features_per_phase, self.n_features)
+
+            if feature_start_idx >= feature_end_idx:
+                # TODO: Adjust the feature selection logic so this never happens, just use more_itertools
+                continue
+
             context_window_data = self.compute_context_window_data(feature_start_idx, feature_end_idx)
             top_acts_data = self.compute_top_activations(context_window_data)
-            for h in trange(0, feature_end_idx - feature_start_idx, desc='making histograms'):
+            for h in trange(0, feature_end_idx - feature_start_idx, desc='making histograms', disable=True):
                 # make and save histogram of logits for this feature
                 feature_id = phase * self.num_features_per_phase + h
                 make_logits_histogram(logits=self.attributed_logits[feature_id, :], feature_id=feature_id, dirpath=self.html_out)
@@ -156,9 +159,7 @@ def compute_context_window_data(self, feature_start_idx, feature_end_idx):
         This should probably also include feature ablations."""
         context_window_data = self._initialize_context_window_data(feature_start_idx, feature_end_idx)

-        for iter in trange(self.num_batches, desc='computing feature activations per batch'):
-            # if iter % 20 == 0:
-            # logging.info(f"computing feature activations for batches {iter+1}-{min(iter+20, self.num_batches)}/{self.num_batches}")
+        for iter in trange(self.num_batches, desc='computing feature activations per batch', disable=True):
             batch_start_idx = iter * self.gpt_batch_size
             batch_end_idx = (iter + 1) * self.gpt_batch_size
             x, feature_activations, logits_difference_storage = self._compute_batch_feature_activations(
@@ -329,10 +330,12 @@ def _sample_context_windows(self, *args, fn_seed=0):

         result_tensors = []
         for tensor in args:
+            assert tensor.numel() > 0, "Tensor has 0 elements"
+
             if tensor.ndim == 3:
                 L = tensor.shape[2]
                 sliced_tensor = tensor[batch_idx, window_idx, :] # (B, S, W, L)
-                sliced_tensor = sliced_tensor.view(-1, self.window_length, L) # (B *S , W, L)
+                sliced_tensor = sliced_tensor.view(-1, self.window_length, L) # (B*S , W, L)
             elif tensor.ndim == 2:
                 sliced_tensor = tensor[batch_idx, window_idx] # (B, S, W)
                 sliced_tensor = sliced_tensor.view(-1, self.window_length) # (B*S, W)
@@ -375,9 +378,9 @@ def _compute_batch_feature_activations(self, batch_start_idx, batch_end_idx, fea
         # TODO: do I need to center the median at 0 before computing differences?
         # Otherwise, the probability of sampling the token can probably not be compared through the logit weight alone.
         logits_difference_storage = torch.zeros(B, T, H, device=self.device) # (B, T, H)
-        for h in trange(H, desc='computing logits with replacement tensor'):
+        for h in trange(H, desc='computing logits with replacement tensor', disable=True):
             # on CPU, each forward pass takes ~12 seconds
-            # on MPS, each forward pass takes ~200 ms (!!!)
+            # on MPS, each forward pass takes ~200 ms (!!)
             feat_ablation_logits, _ = self.transformer(x, y, mode="replace", replacement_tensor=feature_ablations[:, :, :, h]) # (B, T, V)

             logits_difference = original_logits - feat_ablation_logits # (B, T, V)
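
Note: with the defaults stripped from FeatureBrowserConfig, the config can no longer be constructed implicitly; every field has to be passed explicitly. A minimal sketch of what construction now looks like, using the hyperparameter values visible in this diff (the import path and the two checkpoint directory names are hypothetical placeholders, since the script now requires them to be provided):

    from build_website import FeatureBrowserConfig  # assumed import path

    config = FeatureBrowserConfig(
        dataset='openwebtext',
        gpt_ckpt_dir='out',   # hypothetical: directory of the trained GPT checkpoint
        sae_ckpt_dir='1.0',   # hypothetical: subdirectory of the SAE run to browse
        num_contexts=10000,
        num_sampled_tokens=10,
        window_radius=4,
        num_top_activations=10,
        num_intervals=12,
        samples_per_interval=5,
        seed=1442,
        device='mps',
        gpt_batch_size=156,
        num_phases=52,
    )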

autoencoder/resource_loader.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def init_autoencoder_data_info(self):
     def load_autoencoder_model(self):
         """Loads the AutoEncoder model with pre-trained weights"""
         autoencoder_path = os.path.join(self.base_dir, "autoencoder", "out", self.dataset, self.sae_ckpt_dir)
-        autoencoder_ckpt = torch.load(os.path.join(autoencoder_path, 'ckpt.pt'), map_location=self.device)
+        autoencoder_ckpt = torch.load(os.path.join(autoencoder_path, 'ckpt.pt'), map_location=self.device, weights_only=False)
         state_dict = autoencoder_ckpt['autoencoder']
         n_features, n_ffwd = state_dict['encoder.weight'].shape # H, F
         l1_coeff = autoencoder_ckpt['config']['l1_coeff']
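
Note on the new weights_only=False argument: newer PyTorch releases (2.6 and later) changed the default of torch.load to weights_only=True, which refuses to unpickle arbitrary Python objects such as the training config stored in this checkpoint. A minimal sketch of the call, with a hypothetical checkpoint path; only pass weights_only=False for checkpoints you trust, since full unpickling can execute arbitrary code:

    import torch

    ckpt = torch.load('autoencoder/out/openwebtext/1.0/ckpt.pt',  # hypothetical path
                      map_location='cpu', weights_only=False)
    state_dict = ckpt['autoencoder']       # same keys as load_autoencoder_model above
    l1_coeff = ckpt['config']['l1_coeff']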

autoencoder/train_autoencoder.py

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
 dataset = 'openwebtext'
 gpt_ckpt_dir = 'out'
 # training
+# TODO: Rename n_features to be latent_multiple, and base it off the gpt model
 n_features = 4096 # aka n_latents
 batch_size = 8192 # batch size for autoencoder training
 l1_coeff = 3e-3
@@ -135,7 +136,7 @@
         _, nll_loss = gpt(x, y)
         mlp_acts = gpt.mlp_activation_hooks[0]
         gpt.clear_mlp_activation_hooks() # free up memory
-        _, ablated_loss = gpt(x, y, mode="replace")
+        # _, ablated_loss = gpt(x, y, mode="replace")

         with torch.no_grad():
             autoencoder_output = autoencoder(mlp_acts)
@@ -151,7 +152,7 @@
             log_dict['losses/autoencoder_loss'] += autoencoder_output['loss'].item()
             log_dict['losses/reconstruction_loss'] += autoencoder_output['mse_loss'].item()
             log_dict['losses/l1_norm'] += autoencoder_output['l1_loss'].item()
-            log_dict['losses/nll_score'] += (nll_loss - reconstructed_nll).item() / (nll_loss - ablated_loss).item()
+            # log_dict['losses/nll_score'] += (nll_loss - reconstructed_nll).item() / (nll_loss - ablated_loss).item()

             # compute feature densities and plot feature density histogram
             log_feat_acts_density = np.log10(
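
For reference, the nll_score being disabled here normalizes the loss gap left by the autoencoder's reconstruction against the gap caused by ablating the MLP activations entirely; it depends on ablated_loss, so it is commented out together with the ablated forward pass above. A minimal sketch of the metric exactly as it appears in the removed log line:

    def nll_score(nll_loss: float, reconstructed_nll: float, ablated_loss: float) -> float:
        """Score from the removed log line: 0.0 means the reconstruction restores the
        clean loss exactly, 1.0 means it is no better than ablating the MLP activations."""
        return (nll_loss - reconstructed_nll) / (nll_loss - ablated_loss)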

transformer/config/train_gpt2.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

transformer/model.py

Lines changed: 0 additions & 99 deletions
@@ -8,7 +8,6 @@
 """

 import math
-import inspect
 from dataclasses import dataclass

 import torch
@@ -167,99 +166,6 @@ def _init_weights(self, module):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

-    def forward(self, idx, targets=None):
-        device = idx.device
-        b, t = idx.size()
-        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
-        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
-
-        # forward the GPT model itself
-        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
-        x = self.transformer.ln_f(x)
-
-        if targets is not None:
-            # if we are given some desired targets also calculate the loss
-            logits = self.lm_head(x)
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-        else:
-            # inference-time mini-optimization: only forward the lm_head on the very last position
-            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
-            loss = None
-
-        return logits, loss
-
-    def crop_block_size(self, block_size):
-        # model surgery to decrease the block size if necessary
-        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
-        # but want to use a smaller block size for some smaller, simpler model
-        assert block_size <= self.config.block_size
-        self.config.block_size = block_size
-        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
-        for block in self.transformer.h:
-            if hasattr(block.attn, 'bias'):
-                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
-
-    @classmethod
-    def from_pretrained(cls, model_type, override_args=None):
-        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
-        override_args = override_args or {} # default to empty dict
-        # only dropout can be overridden see more notes below
-        assert all(k == 'dropout' for k in override_args)
-        from transformers import GPT2LMHeadModel
-        print("loading weights from pretrained gpt: %s" % model_type)
-
-        # n_layer, n_head and n_embd are determined from model_type
-        config_args = {
-            'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
-            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
-            'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
-            'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
-        }[model_type]
-        print("forcing vocab_size=50257, block_size=1024, bias=True")
-        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
-        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
-        config_args['bias'] = True # always True for GPT model checkpoints
-        # we can override the dropout rate, if desired
-        if 'dropout' in override_args:
-            print(f"overriding dropout rate to {override_args['dropout']}")
-            config_args['dropout'] = override_args['dropout']
-        # create a from-scratch initialized minGPT model
-        config = GPTConfig(**config_args)
-        model = GPT(config)
-        sd = model.state_dict()
-        sd_keys = sd.keys()
-        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
-
-        # init a huggingface/transformers model
-        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
-        sd_hf = model_hf.state_dict()
-
-        # copy while ensuring all of the parameters are aligned and match in names and shapes
-        sd_keys_hf = sd_hf.keys()
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
-        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
-        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
-        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
-        # this means that we have to transpose these weights when we import them
-        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
-        for k in sd_keys_hf:
-            if any(k.endswith(w) for w in transposed):
-                # special treatment for the Conv1D weights we need to transpose
-                assert sd_hf[k].shape[::-1] == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k].t())
-            else:
-                # vanilla copy over the other parameters
-                assert sd_hf[k].shape == sd[k].shape
-                with torch.no_grad():
-                    sd[k].copy_(sd_hf[k])
-
-        return model
-
     def configure_optimizers(self, weight_decay, learning_rate, betas):
         # start with all of the candidate parameters
         param_dict = {pn: p for pn, p in self.named_parameters()}
@@ -278,12 +184,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas):
         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

-        # Create AdamW optimizer and use the fused version if it is available
-        # fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
-        # use_fused = fused_available and device_type == 'cuda'
-        # extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=False)
-        # print(f"using fused AdamW: {use_fused}")

         return optimizer
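
Side note on the fused-AdamW cleanup above: the deleted comments probed whether the installed torch build's AdamW accepts a fused argument and only enabled it on CUDA; the commit now hard-codes fused=False, which is the safe choice on CPU and MPS. A minimal sketch of that detection logic, reconstructed from the removed comments with torch.cuda.is_available() standing in for the original device_type check:

    import inspect
    import torch

    # Older torch builds do not accept a `fused` argument at all, so probe the signature first.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    # The removed code only turned the fused kernel on when running on CUDA.
    use_fused = fused_available and torch.cuda.is_available()
    extra_args = dict(fused=True) if use_fused else dict()
    print(f"using fused AdamW: {use_fused}")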

Lines changed: 0 additions & 29 deletions
@@ -153,37 +153,8 @@ def get_batch(split):

 elif init_from == 'resume':
     raise DeprecationWarning('init from is deprecated')
-    print(f"Resuming training from {out_dir}")
-    # resume training from a checkpoint.
-    ckpt_path = os.path.join(out_dir, 'ckpt.pt')
-    checkpoint = torch.load(ckpt_path, map_location=device)
-    checkpoint_model_args = checkpoint['model_args']
-    # force these config attributes to be equal otherwise we can't even resume training
-    # the rest of the attributes (e.g. dropout) can stay as desired from command line
-    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
-        model_args[k] = checkpoint_model_args[k]
-    # create the model
-    gptconf = GPTConfig(**model_args)
-    model = GPT(gptconf)
-    state_dict = checkpoint['model']
-    # fix the keys of the state dictionary :(
-    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
-    unwanted_prefix = '_orig_mod.'
-    for k,v in list(state_dict.items()):
-        if k.startswith(unwanted_prefix):
-            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-    model.load_state_dict(state_dict)
-    iter_num = checkpoint['iter_num']
-    best_val_loss = checkpoint['best_val_loss']
 elif init_from.startswith('gpt2'):
     raise DeprecationWarning('init from is deprecated')
-    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
-    # initialize from OpenAI GPT-2 weights
-    override_args = dict(dropout=dropout)
-    model = GPT.from_pretrained(init_from, override_args)
-    # read off the created config params, so we can store them into checkpoint correctly
-    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
-        model_args[k] = getattr(model.config, k)

 # crop down the model block size if desired, using model surgery
 if block_size < model.config.block_size:
