
Commit cb4ae42

A bit of cleanup, file naming clarity
1 parent b37d2b5 commit cb4ae42

File tree

4 files changed: +12 -27 lines changed

File renamed without changes.

autoencoder/resource_loader.py

Lines changed: 6 additions & 8 deletions
@@ -58,7 +58,7 @@ def load_text_data(self):
     def load_transformer_model(self):
         """Loads the GPT model with pre-trained weights."""
         ckpt_path = os.path.join(self.base_dir, 'transformer', self.gpt_ckpt_dir, 'ckpt.pt')
-        checkpoint = torch.load(ckpt_path, map_location=self.device)
+        checkpoint = torch.load(ckpt_path, map_location=self.device, weights_only=False)
         gpt_conf = GPTConfig(**checkpoint['model_args'])
         transformer = HookedGPT(gpt_conf)
         state_dict = checkpoint['model']
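
For context on the weights_only=False additions in this commit: newer PyTorch releases default torch.load to weights_only=True, which restricts unpickling to a safelisted set of types, so passing weights_only=False explicitly keeps the legacy full-unpickling behaviour these checkpoints rely on. A minimal sketch; the path is illustrative, while the keys mirror what load_transformer_model reads above:

import torch

# Illustrative checkpoint load: weights_only=False allows arbitrary pickled
# objects (configs, argparse namespaces, ...) stored next to the tensor weights.
checkpoint = torch.load('out/ckpt.pt', map_location='cpu', weights_only=False)
model_args = checkpoint['model_args']   # non-tensor payload
state_dict = checkpoint['model']        # the tensor weights themselves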
@@ -100,7 +100,7 @@ def load_autoencoder_model(self):
         state_dict = autoencoder_ckpt['autoencoder']
         n_features, n_ffwd = state_dict['encoder.weight'].shape # H, F
         l1_coeff = autoencoder_ckpt['config']['l1_coeff']
-        from autoencoder import AutoEncoder
+        from autoencoder_architecture import AutoEncoder
 
         autoencoder = AutoEncoder(n_ffwd, n_features, lam=l1_coeff).to(self.device)
         autoencoder.load_state_dict(state_dict)
@@ -114,7 +114,7 @@ def get_text_batch(self, num_contexts):
         Y = torch.stack([torch.from_numpy(self.text_data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix])
         return X.to(device=self.device), Y.to(device=self.device)
 
-    def get_autoencoder_data_batch(self, step, batch_size=8192):
+    def get_autoencoder_data_batch(self, step, batch_size: int):
         """
         Retrieves a batch of autoencoder data based on the step and batch size.
         It loads the next data partition if the batch exceeds the current partition.
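
The docstring above describes rolling over to the next data partition when a batch would run past the end of the current one. A rough standalone sketch of that indexing idea; the function and variable names here are invented, not the repository's API:

import torch

def get_batch_across_partitions(partition, next_partition, step, batch_size):
    # Index a batch by global step; spill into the next partition at the boundary.
    n = partition.shape[0]
    start = (step * batch_size) % n
    end = start + batch_size
    if end <= n:
        return partition[start:end]
    # Batch straddles the boundary: tail of this partition + head of the next.
    return torch.cat([partition[start:], next_partition[: end - n]])

# Example with two fake partitions of 10 examples, 4 features each.
p0, p1 = torch.randn(10, 4), torch.randn(10, 4)
batch = get_batch_across_partitions(p0, p1, step=2, batch_size=4)
print(batch.shape)  # torch.Size([4, 4]): rows 8-9 of p0 plus rows 0-1 of p1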
@@ -141,14 +141,12 @@ def get_autoencoder_data_batch(self, step, batch_size=8192):
         return batch.to(self.device)
 
     def load_next_autoencoder_partition(self, partition_id):
-        """
-        Loads the specified partition of the autoencoder data.
-        """
+        """Loads the specified partition of the autoencoder data."""
         file_path = os.path.join(self.autoencoder_data_dir, f'{partition_id}.pt')
-        self.autoencoder_data = torch.load(file_path)
+        self.autoencoder_data = torch.load(file_path, weights_only=False)
         return self.autoencoder_data
 
-    def select_resampling_data(self, size=819200):
+    def select_resampling_data(self, size: int):
         """
         Selects a subset of autoencoder data for resampling, distributed evenly across partitions.
         """
Lines changed: 6 additions & 19 deletions
@@ -2,14 +2,14 @@
 Train a Sparse AutoEncoder model
 
 Run on a macbook on a Shakespeare dataset as
-python train.py --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32 --eval_iters=1 --eval_batch_size=16 --batch_size=128 --device=cpu --eval_interval=100 --n_features=1024 --resampling_interval=150 --wandb_log=True
+python train.py --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32 --eval_iters=1 --eval_batch_size=16 --batch_size=128 --device=cpu --eval_interval=100 --n_features=1024 --resampling_interval=150
 """
 
 import os
 import torch
 import numpy as np
 import time
-from autoencoder import AutoEncoder
+from autoencoder_architecture import AutoEncoder
 from resource_loader import ResourceLoader
 from utils.plotting_utils import make_density_histogram
 
@@ -18,7 +18,7 @@
 dataset = 'openwebtext'
 gpt_ckpt_dir = 'out'
 # training
-n_features = 4096
+n_features = 4096 # aka n_latents
 batch_size = 8192 # batch size for autoencoder training
 l1_coeff = 3e-3
 learning_rate = 3e-4
@@ -33,8 +33,6 @@
 save_checkpoint = True # whether to save model, optimizer, etc or not
 save_interval = 10000 # number of training steps after which a checkpoint will be saved
 out_dir = 'out' # directory containing trained autoencoder model weights
-# wandb logging
-wandb_log = True
 # system
 device = 'cuda'
 # reproducibility
@@ -57,20 +55,15 @@
 
 gpt = resourceloader.transformer # TODO: either it should be called transformer or gpt
 autoencoder = AutoEncoder(
-    n_inputs=(4 * resourceloader.transformer.config.n_embd), # ?? why 4x?
+    n_inputs=(4 * resourceloader.transformer.config.n_embd),
     n_latents=n_features,
     lam=l1_coeff,
     resampling_interval=resampling_interval
     ).to(device)
 optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate)
 
 ## prepare for logging and saving checkpoints
-run_name = f'{time.time():.2f}'
-if wandb_log:
-    raise DeprecationWarning('wandb is deprecated')
-    import wandb
-
-    wandb.init(project=f'sparse-autoencoder-{dataset}', name=run_name, config=config)
+run_name = time.strftime('%Y-%m-%d-%H%M')
 
 if save_checkpoint:
     ckpt_path = os.path.join(out_dir, dataset, run_name)
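
A quick illustration of the timestamped run name and checkpoint directory this now produces, assuming the defaults out_dir = 'out' and dataset = 'openwebtext' from the config above:

import os
import time

out_dir, dataset = 'out', 'openwebtext'
run_name = time.strftime('%Y-%m-%d-%H%M')             # e.g. '2024-07-05-1412'
ckpt_path = os.path.join(out_dir, dataset, run_name)  # e.g. 'out/openwebtext/2024-07-05-1412'
print(ckpt_path)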
@@ -181,9 +174,6 @@
                 'feature_density/num_alive_neurons': len(log_feat_acts_density),
             }
         )
-        if wandb_log:
-            log_dict.update({'feature_density/feature_density_histograms': wandb.Image(feat_density_historgram)})
-            wandb.log(log_dict)
 
         autoencoder.train()
         print(f'Exiting evaluation mode at step = {step}')
@@ -198,7 +188,4 @@
             'feature_activation_counts': feat_acts_count, # may be used later to identify alive vs dead neurons
         }
         print(f"saving checkpoint to {ckpt_path} at training step = {step}")
-        torch.save(checkpoint, os.path.join(ckpt_path, 'ckpt.pt'))
-
-if wandb_log:
-    wandb.finish()
+        torch.save(checkpoint, os.path.join(ckpt_path, 'ckpt.pt'))
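
To see how this saved checkpoint round-trips with the resource_loader change above, a hedged sketch of reloading it; the 'autoencoder' and 'config' keys mirror what load_autoencoder_model reads, and everything else here is illustrative:

import os
import torch

# Sketch: reload the training checkpoint saved above (weights_only=False, as in ResourceLoader).
ckpt = torch.load(os.path.join(ckpt_path, 'ckpt.pt'), map_location='cpu', weights_only=False)
state_dict = ckpt['autoencoder']                        # autoencoder weights
n_features, n_ffwd = state_dict['encoder.weight'].shape
l1_coeff = ckpt['config']['l1_coeff']                   # hyperparameters saved alongside the weights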
