Commit

format

jbloomAus committed Apr 15, 2024
1 parent c41774e commit c359c27
Showing 8 changed files with 51 additions and 7 deletions.
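
Aside from the test file, every change below is the same one-line fix: `import wandb` moves out of the first-party import block and up into the third-party block, following the standard stdlib / third-party / local grouping. The commit message is just "format", so the exact tool is an assumption; the pattern matches what an import sorter such as isort produces once `wandb` is classified as third-party. A minimal sketch of the resulting layout, using `lm_runner.py`'s imports:

# Standard library imports come first.
from typing import Any, cast

# Third-party packages come second -- wandb now sits here, beside its peers.
import wandb

# First-party (sae_lens) imports come last.
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
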
2 changes: 1 addition & 1 deletion sae_lens/analysis/dashboard_runner.py
@@ -11,6 +11,7 @@
 import plotly
 import plotly.express as px
 import torch
+import wandb
 from sae_vis.data_config_classes import (
     ActsHistogramConfig,
     Column,
@@ -23,7 +24,6 @@
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
-import wandb
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader


2 changes: 1 addition & 1 deletion sae_lens/training/evals.py
@@ -3,10 +3,10 @@

 import pandas as pd
 import torch
+import wandb
 from transformer_lens import HookedTransformer
 from transformer_lens.utils import get_act_name
 
-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder

1 change: 1 addition & 0 deletions sae_lens/training/lm_runner.py
@@ -1,6 +1,7 @@
 from typing import Any, cast
 
+import wandb
 
 from sae_lens.training.config import LanguageModelSAERunnerConfig
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

2 changes: 1 addition & 1 deletion sae_lens/training/toy_model_runner.py
@@ -3,8 +3,8 @@

 import einops
 import torch
-
 import wandb
+
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
 from sae_lens.training.toy_models import Config as ToyConfig
 from sae_lens.training.toy_models import Model as ToyModel
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_language_model.py
@@ -3,13 +3,13 @@
 from typing import Any, cast
 
 import torch
+import wandb
 from safetensors.torch import save_file
 from torch.optim import Adam, Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 
-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.evals import run_evals
 from sae_lens.training.geometric_median import compute_geometric_median
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_toy_model.py
@@ -1,10 +1,10 @@
 from typing import Any, cast
 
 import torch
+import wandb
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-import wandb
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder


45 changes: 44 additions & 1 deletion tests/unit/training/test_sparse_autoencoder.py
@@ -136,6 +136,47 @@ def test_sparse_autoencoder_forward(sparse_autoencoder: SparseAutoencoder):
     assert l1_loss.shape == ()
     assert torch.allclose(loss, mse_loss + l1_loss)
 
+    expected_mse_loss = (torch.pow((sae_out - x.float()), 2)).mean()
+
+    assert torch.allclose(mse_loss, expected_mse_loss)
+    expected_l1_loss = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
+    assert torch.allclose(l1_loss, sparse_autoencoder.l1_coefficient * expected_l1_loss)
+
+    # check everything has the right dtype
+    assert sae_out.dtype == sparse_autoencoder.dtype
+    assert feature_acts.dtype == sparse_autoencoder.dtype
+    assert loss.dtype == sparse_autoencoder.dtype
+    assert mse_loss.dtype == sparse_autoencoder.dtype
+    assert l1_loss.dtype == sparse_autoencoder.dtype
+
+
+def test_sparse_autoencoder_forward_with_mse_loss_norm(
+    sparse_autoencoder: SparseAutoencoder,
+):
+    batch_size = 32
+    d_in = sparse_autoencoder.d_in
+    d_sae = sparse_autoencoder.d_sae
+    sparse_autoencoder.cfg.mse_loss_normalization = "dense_batch"
+
+    x = torch.randn(batch_size, d_in)
+    (
+        sae_out,
+        feature_acts,
+        loss,
+        mse_loss,
+        l1_loss,
+        _ghost_grad_loss,
+    ) = sparse_autoencoder.forward(
+        x,
+    )
+
+    assert sae_out.shape == (batch_size, d_in)
+    assert feature_acts.shape == (batch_size, d_sae)
+    assert loss.shape == ()
+    assert mse_loss.shape == ()
+    assert l1_loss.shape == ()
+    assert torch.allclose(loss, mse_loss + l1_loss)
+
     x_centred = x - x.mean(dim=0, keepdim=True)
     expected_mse_loss = (
         torch.pow((sae_out - x.float()), 2)
@@ -199,7 +240,9 @@ def test_per_item_mse_loss_with_norm_matches_original_implementation() -> None:
         torch.pow((input - target.float()), 2)
         / (target_centered**2).sum(dim=-1, keepdim=True).sqrt()
     )
-    sae_res = _per_item_mse_loss_with_target_norm(input, target)
+    sae_res = _per_item_mse_loss_with_target_norm(
+        input, target, mse_loss_normalization="dense_batch"
+    )
     assert torch.allclose(orig_impl_res, sae_res, atol=1e-5)


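The new test assertions above pin down the loss semantics: the plain MSE is a mean over every element of the reconstruction error, the L1 term is a per-example sum of absolute feature activations averaged over the batch and scaled by `l1_coefficient`, and "dense_batch" normalization divides each example's squared error by the L2 norm of the batch-centred target. A minimal self-contained sketch of those three quantities, using random stand-in tensors in place of a real SAE forward pass; the shapes and coefficient are illustrative, and the final scalar reduction of the normalized loss is assumed to be a mean, matching the shape assertions above:

import torch

torch.manual_seed(0)
batch_size, d_in, d_sae = 32, 64, 128  # illustrative sizes, not SAELens defaults
l1_coefficient = 1e-3                  # illustrative sparsity coefficient

x = torch.randn(batch_size, d_in)                     # SAE input activations
sae_out = torch.randn(batch_size, d_in)               # stand-in for the reconstruction
feature_acts = torch.randn(batch_size, d_sae).relu()  # stand-in for hidden features

# Plain MSE: mean squared reconstruction error over all elements.
mse_loss = (torch.pow(sae_out - x.float(), 2)).mean()

# L1 penalty: sum |activation| per example, average over the batch, then scale.
l1_loss = l1_coefficient * torch.abs(feature_acts).sum(dim=1).mean()

loss = mse_loss + l1_loss

# "dense_batch" normalization: divide each example's squared error by the
# L2 norm of the batch-centred target before reducing to a scalar.
x_centred = x - x.mean(dim=0, keepdim=True)
per_item = torch.pow(sae_out - x.float(), 2) / (x_centred**2).sum(
    dim=-1, keepdim=True
).sqrt()
normalized_mse_loss = per_item.mean()
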
2 changes: 1 addition & 1 deletion tests/unit/training/test_train_sae_on_language_model.py
@@ -5,11 +5,11 @@

 import pytest
 import torch
+import wandb
 from datasets import Dataset
 from torch import Tensor
 from transformer_lens import HookedTransformer
 
-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.optim import get_scheduler
 from sae_lens.training.sae_group import SparseAutoencoderDictionary
