Commit 7861f32

Simplify logging (ai-safety-foundation#65)
Use a single main tqdm progress bar, since nested progress bars aren't well supported in VS Code.
1 parent 6a69e67 commit 7861f32
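The commit collapses the nested progress bars into one top-level bar whose postfix reports the current phase. Below is a minimal sketch of that pattern; the loop body and names such as `run` and `activations_per_iteration` are illustrative placeholders, not the repository's API.

```python
from tqdm.auto import tqdm


def run(max_activations: int, activations_per_iteration: int) -> None:
    """One flat progress bar; the phase is shown via the postfix instead of nested bars."""
    with tqdm(desc="Total activations trained on", total=max_activations) as progress_bar:
        total_activations = 0
        while total_activations < max_activations:
            progress_bar.set_postfix({"Current mode": "generating"})
            # ... generate activations into the store (placeholder) ...

            progress_bar.set_postfix({"Current mode": "training"})
            # ... train the autoencoder on the store (placeholder) ...

            total_activations += activations_per_iteration
            progress_bar.update(activations_per_iteration)
```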

File tree: 3 files changed (+53, −72 lines)
sparse_autoencoder/train/generate_activations.py

4 additions, 9 deletions

@@ -5,7 +5,6 @@
 from jaxtyping import Int
 import torch
 from torch import Tensor
-from tqdm.auto import tqdm
 from transformer_lens import HookedTransformer
 
 from sparse_autoencoder.activation_store.base_store import (
@@ -57,6 +56,9 @@ def generate_activations(
             than strict limit.
         device: Device to run the model on.
     """
+    # Set model to evaluation (inference) mode
+    model.eval()
+
     if isinstance(device, torch.device):
         model.to(device, print_details=False)
 
@@ -70,17 +72,10 @@ def generate_activations(
     total: int = num_items - num_items % activations_per_batch
 
     # Loop through the dataloader until the store reaches the desired size
-    with torch.no_grad(), tqdm(
-        desc="Generate Activations",
-        total=total - total % activations_per_batch,
-        colour="green",
-        leave=False,
-        dynamic_ncols=True,
-    ) as progress_bar:
+    with torch.no_grad():
         for batch in source_data:
             if len(store) + activations_per_batch > total:
                 break
 
             input_ids: Int[Tensor, "batch pos"] = batch["input_ids"].to(device)
             model.forward(input_ids, stop_at_layer=layer + 1)  # type: ignore (TLens is typed incorrectly)
-            progress_bar.update(activations_per_batch)

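With the inner progress bar gone, the generation loop relies on the pipeline's single bar; inference itself is the usual eval-mode, no-grad pattern with an early-exit forward pass. A rough sketch assuming a transformer_lens `HookedTransformer`; the hook wiring that actually fills the activation store is omitted, and the model and prompt are arbitrary examples.

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
model.eval()  # evaluation (inference) mode, as added above

tokens = model.to_tokens("An example prompt")
with torch.no_grad():  # no gradients needed when only harvesting activations
    # Stop the forward pass once the layer of interest has run.
    model.forward(tokens, stop_at_layer=2)
```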
sparse_autoencoder/train/pipeline.py

8 additions, 10 deletions

@@ -114,23 +114,23 @@ def pipeline(  # noqa: PLR0913
     total_steps: int = 0
     activations_since_resampling: int = 0
     neuron_activity: Int[Tensor, " learned_features"] = torch.zeros(
-        autoencoder.n_learned_features, dtype=torch.int32, device=device
+        autoencoder.n_learned_features,
+        dtype=torch.int32,
+        device=device,
     )
     total_activations: int = 0
-    generate_train_iterations: int = 0
 
     # Run loop until source data is exhausted:
     with logging_redirect_tqdm(), tqdm(
         desc="Total activations trained on",
         dynamic_ncols=True,
-        colour="blue",
         total=max_activations,
-        postfix={"Generate/train iterations": 0},
+        postfix={"Current mode": "initializing"},
     ) as progress_bar:
         while total_activations < max_activations:
-            activation_store.empty()  # In case it was filled by a different run
-
             # Add activations to the store
+            activation_store.empty()  # In case it was filled by a different run
+            progress_bar.set_postfix({"Current mode": "generating"})
             generate_activations(
                 src_model,
                 src_model_activation_layer,
@@ -150,6 +150,7 @@ def pipeline(  # noqa: PLR0913
             activation_store.shuffle()
 
             # Train the autoencoder
+            progress_bar.set_postfix({"Current mode": "training"})
             train_steps, learned_activations_fired_count = train_autoencoder(
                 activation_store=activation_store,
                 autoencoder=autoencoder,
@@ -169,6 +170,7 @@ def pipeline(  # noqa: PLR0913
 
             # Resample neurons if required
             if activations_since_resampling >= resample_frequency:
+                progress_bar.set_postfix({"Current mode": "resampling"})
                 activations_since_resampling = 0
                 resample_dead_neurons(
                     neuron_activity=neuron_activity,
@@ -180,7 +182,3 @@ def pipeline(  # noqa: PLR0913
                 optimizer.reset_state_all_parameters()
 
             activation_store.empty()
-
-            progress_bar.update(1)
-            generate_train_iterations += 1
-            progress_bar.set_postfix({"Generate/train iterations": generate_train_iterations})
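The single bar is still wrapped in `logging_redirect_tqdm()` (visible in the context lines above), so ordinary log records are printed above the bar rather than breaking it. A small standalone sketch of that combination; the logger name, total, and messages are illustrative only.

```python
import logging

from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("pipeline")

with logging_redirect_tqdm(), tqdm(desc="Total activations trained on", total=100) as progress_bar:
    for step in range(100):
        if step % 25 == 0:
            log.info("reached step %d", step)  # redirected above the bar, not through it
        progress_bar.update(1)
```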
41 additions, 53 deletions

@@ -1,10 +1,9 @@
 """Training Pipeline."""
 from jaxtyping import Float, Int
 import torch
-from torch import Tensor, device, set_grad_enabled
+from torch import Tensor, device
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
 import wandb
 
 from sparse_autoencoder.activation_store.base_store import ActivationStore
@@ -46,63 +45,52 @@ def train_autoencoder(
         batch_size=sweep_parameters.batch_size,
     )
 
-    n_dataset_items: int = len(activation_store)
-    batch_size: int = sweep_parameters.batch_size
-
     learned_activations_fired_count: Int[Tensor, " learned_feature"] = torch.zeros(
         autoencoder.n_learned_features, dtype=torch.int32, device=device
     )
 
-    step = 0
-    with set_grad_enabled(True), tqdm(  # noqa: FBT003
-        desc="Train Autoencoder",
-        total=n_dataset_items,
-        colour="green",
-        leave=False,
-        dynamic_ncols=True,
-    ) as progress_bar:
-        for step, batch in enumerate(activations_dataloader):
-            # Zero the gradients
-            optimizer.zero_grad()
-
-            # Move the batch to the device (in place)
-            batch = batch.to(device)  # noqa: PLW2901
-
-            # Forward pass
-            learned_activations, reconstructed_activations = autoencoder(batch)
-
-            # Get metrics
-            reconstruction_loss_mse: Float[Tensor, " item"] = reconstruction_loss(
-                batch,
-                reconstructed_activations,
-            )
-            l1_loss_learned_activations: Float[Tensor, " item"] = l1_loss(learned_activations)
-            total_loss: Float[Tensor, " item"] = sae_training_loss(
-                reconstruction_loss_mse,
-                l1_loss_learned_activations,
-                sweep_parameters.l1_coefficient,
-            )
-
-            # Store count of how many neurons have fired
+    step: int = 0  # Initialize step
+    for step, store_batch in enumerate(activations_dataloader):
+        # Zero the gradients
+        optimizer.zero_grad()
+
+        # Move the batch to the device (in place)
+        batch = store_batch.detach().to(device)
+
+        # Forward pass
+        learned_activations, reconstructed_activations = autoencoder(batch)
+
+        # Get metrics
+        reconstruction_loss_mse: Float[Tensor, " item"] = reconstruction_loss(
+            batch,
+            reconstructed_activations,
+        )
+        l1_loss_learned_activations: Float[Tensor, " item"] = l1_loss(learned_activations)
+        total_loss: Float[Tensor, " item"] = sae_training_loss(
+            reconstruction_loss_mse,
+            l1_loss_learned_activations,
+            sweep_parameters.l1_coefficient,
+        )
+
+        # Store count of how many neurons have fired
+        with torch.no_grad():
             fired = learned_activations > 0
             learned_activations_fired_count.add_(fired.sum(dim=0))
 
-            # Backwards pass
-            total_loss.mean().backward()
-
-            optimizer.step()
-
-            # Log
-            if step % log_interval == 0 and wandb.run is not None:
-                wandb.log(
-                    {
-                        "reconstruction_loss": reconstruction_loss_mse.mean().item(),
-                        "l1_loss": l1_loss_learned_activations.mean().item(),
-                        "loss": total_loss.mean().item(),
-                    },
-                )
+        # Backwards pass
+        total_loss.mean().backward()
+        optimizer.step()
+
+        # Log
+        if step % log_interval == 0 and wandb.run is not None:
+            wandb.log(
+                {
+                    "reconstruction_loss": reconstruction_loss_mse.mean().item(),
+                    "l1_loss": l1_loss_learned_activations.mean().item(),
+                    "loss": total_loss.mean().item(),
+                },
+            )
 
-            progress_bar.update(batch_size)
+    current_step = previous_steps + step + 1
 
-    current_step = previous_steps + step + 1
-    return current_step, learned_activations_fired_count
+    return current_step, learned_activations_fired_count
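The firing-count bookkeeping kept above (now under an explicit `torch.no_grad()`) can be exercised in isolation with toy shapes; `n_learned_features` and the random batch below are illustrative only.

```python
import torch

n_learned_features = 8
learned_activations_fired_count = torch.zeros(n_learned_features, dtype=torch.int32)

# Fake batch of learned activations: [batch, learned_features], non-negative after ReLU.
learned_activations = torch.relu(torch.randn(4, n_learned_features))

with torch.no_grad():  # pure bookkeeping, keep it out of the autograd graph
    fired = learned_activations > 0                          # which features activated per item
    learned_activations_fired_count.add_(fired.sum(dim=0))   # per-feature count across the batch

print(learned_activations_fired_count)
```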
