@@ -1,10 +1,9 @@
 """Training Pipeline."""
 from jaxtyping import Float, Int
 import torch
-from torch import Tensor, device, set_grad_enabled
+from torch import Tensor, device
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
 import wandb
 
 from sparse_autoencoder.activation_store.base_store import ActivationStore
@@ -46,63 +45,52 @@ def train_autoencoder(
         batch_size=sweep_parameters.batch_size,
     )
 
-    n_dataset_items: int = len(activation_store)
-    batch_size: int = sweep_parameters.batch_size
-
     learned_activations_fired_count: Int[Tensor, " learned_feature"] = torch.zeros(
         autoencoder.n_learned_features, dtype=torch.int32, device=device
     )
 
-    step = 0
-    with set_grad_enabled(True), tqdm(  # noqa: FBT003
-        desc="Train Autoencoder",
-        total=n_dataset_items,
-        colour="green",
-        leave=False,
-        dynamic_ncols=True,
-    ) as progress_bar:
-        for step, batch in enumerate(activations_dataloader):
-            # Zero the gradients
-            optimizer.zero_grad()
-
-            # Move the batch to the device (in place)
-            batch = batch.to(device)  # noqa: PLW2901
-
-            # Forward pass
-            learned_activations, reconstructed_activations = autoencoder(batch)
-
-            # Get metrics
-            reconstruction_loss_mse: Float[Tensor, " item"] = reconstruction_loss(
-                batch,
-                reconstructed_activations,
-            )
-            l1_loss_learned_activations: Float[Tensor, " item"] = l1_loss(learned_activations)
-            total_loss: Float[Tensor, " item"] = sae_training_loss(
-                reconstruction_loss_mse,
-                l1_loss_learned_activations,
-                sweep_parameters.l1_coefficient,
-            )
-
-            # Store count of how many neurons have fired
+    step: int = 0  # Keep step defined even if the dataloader is empty
+    for step, store_batch in enumerate(activations_dataloader):
+        # Zero the gradients
+        optimizer.zero_grad()
+
+        # Detach the batch from any existing graph and move it to the device
+        batch = store_batch.detach().to(device)
+
+        # Forward pass
+        learned_activations, reconstructed_activations = autoencoder(batch)
+
+        # Get metrics
+        reconstruction_loss_mse: Float[Tensor, " item"] = reconstruction_loss(
+            batch,
+            reconstructed_activations,
+        )
+        l1_loss_learned_activations: Float[Tensor, " item"] = l1_loss(learned_activations)
+        total_loss: Float[Tensor, " item"] = sae_training_loss(
+            reconstruction_loss_mse,
+            l1_loss_learned_activations,
+            sweep_parameters.l1_coefficient,
+        )
+
+        # Store count of how many neurons have fired
+        with torch.no_grad():
             fired = learned_activations > 0
             learned_activations_fired_count.add_(fired.sum(dim=0))
 
-            # Backwards pass
-            total_loss.mean().backward()
-
-            optimizer.step()
-
-            # Log
-            if step % log_interval == 0 and wandb.run is not None:
-                wandb.log(
-                    {
-                        "reconstruction_loss": reconstruction_loss_mse.mean().item(),
-                        "l1_loss": l1_loss_learned_activations.mean().item(),
-                        "loss": total_loss.mean().item(),
-                    },
-                )
+        # Backwards pass
+        total_loss.mean().backward()
+        optimizer.step()
+
+        # Log
+        if step % log_interval == 0 and wandb.run is not None:
+            wandb.log(
+                {
+                    "reconstruction_loss": reconstruction_loss_mse.mean().item(),
+                    "l1_loss": l1_loss_learned_activations.mean().item(),
+                    "loss": total_loss.mean().item(),
+                },
+            )
 
-            progress_bar.update(batch_size)
+    current_step = previous_steps + step + 1
 
-    current_step = previous_steps + step + 1
-    return current_step, learned_activations_fired_count
+    return current_step, learned_activations_fired_count
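For orientation, here is a minimal sketch (not part of the diff) of how the updated `train_autoencoder` might be driven from an outer loop. The keyword names mirror identifiers visible in the diff (`activation_store`, `previous_steps`, `log_interval`), but the exact signature is an assumption, as is the dead-feature check at the end:

```python
import torch

# Assumed to be constructed elsewhere in the pipeline:
# autoencoder, activation_store, optimizer, sweep_parameters.
total_steps = 0
fired_counts = torch.zeros(autoencoder.n_learned_features, dtype=torch.int32)

for _ in range(10):  # hypothetical generate/train cycles
    total_steps, fired = train_autoencoder(
        activation_store=activation_store,
        autoencoder=autoencoder,
        optimizer=optimizer,
        sweep_parameters=sweep_parameters,
        previous_steps=total_steps,
        log_interval=10,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    fired_counts += fired.cpu()

# Features that never fired across training are candidates for resampling.
dead_features = torch.nonzero(fired_counts == 0).squeeze(-1)
```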
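The loss helpers are imported from elsewhere in the package, so this diff does not define them. As a rough sketch of what the three calls above are assumed to compute (the standard sparse-autoencoder objective: per-item MSE plus an L1 sparsity penalty weighted by `l1_coefficient`):

```python
from torch import Tensor


def reconstruction_loss_sketch(
    source_activations: Tensor, reconstructed_activations: Tensor
) -> Tensor:
    """Per-item mean squared error between source and reconstruction."""
    return ((reconstructed_activations - source_activations) ** 2).mean(dim=-1)


def l1_loss_sketch(learned_activations: Tensor) -> Tensor:
    """Per-item L1 norm of the learned activations (encourages sparsity)."""
    return learned_activations.abs().sum(dim=-1)


def sae_training_loss_sketch(
    reconstruction_loss_mse: Tensor,
    l1_loss_learned_activations: Tensor,
    l1_coefficient: float,
) -> Tensor:
    """Per-item total loss: reconstruction error plus weighted sparsity penalty."""
    return reconstruction_loss_mse + l1_coefficient * l1_loss_learned_activations
```

Taking `.mean()` over the per-item losses, as the diff does before calling `.backward()`, reduces this to the scalar used for the optimizer step.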