Skip to content

Commit

Permalink
fix MSE in ghost grad
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Feb 9, 2024
1 parent 01ccb92 commit 44f7988
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sae_training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def forward(self, x, dead_neuron_mask=None):

# 1.
residual = x - sae_out
residual_centred = residual - residual.mean(dim=0, keepdim=True)
l2_norm_residual = torch.norm(residual, dim=-1)

# 2.
Expand All @@ -123,7 +124,7 @@ def forward(self, x, dead_neuron_mask=None):
# 3.
mse_loss_ghost_resid = (
torch.pow((ghost_out - residual.detach().float()), 2)
/ (residual.detach() ** 2).sum(dim=-1, keepdim=True).sqrt()
/ (residual_centred.detach() ** 2).sum(dim=-1, keepdim=True).sqrt()
)
mse_rescaling_factor = (mse_loss / (mse_loss_ghost_resid + 1e-6)).detach()
mse_loss_ghost_resid = mse_rescaling_factor * mse_loss_ghost_resid
Expand Down

0 comments on commit 44f7988

Please sign in to comment.