feat: Added tanh-relu activation fn and input noise options (#77)
* Still need to pip-install the hufy implementation from GitHub.

* Added support for `tanh_sae`.

* Added notebook for loading the `tanh_sae`

* tweaking config options to be more declarative / composable

* testing adding noise to SAE forward pass

* updating notebook

---------

Co-authored-by: David Chanin <chanindav@gmail.com>
HuFY-dev and chanind authored Apr 21, 2024
1 parent 6d45b33 commit 551e94d
Showing 8 changed files with 698 additions and 13 deletions.
18 changes: 18 additions & 0 deletions sae_lens/training/activation_functions.py
@@ -0,0 +1,18 @@
from typing import Callable

import torch


def get_activation_fn(activation_fn: str) -> Callable[[torch.Tensor], torch.Tensor]:
if activation_fn == "relu":
return torch.nn.ReLU()
elif activation_fn == "tanh-relu":
return tanh_relu
else:
raise ValueError(f"Unknown activation function: {activation_fn}")


def tanh_relu(input: torch.Tensor) -> torch.Tensor:
input = torch.relu(input)
input = torch.tanh(input)
return input
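Usage sketch (illustrative, not part of the diff): the factory above can be called directly. The `tanh-relu` composition keeps the sparsity behaviour of ReLU for negative pre-activations while bounding positive activations to [0, 1).

# Illustrative only; mirrors the factory defined above.
import torch
from sae_lens.training.activation_functions import get_activation_fn

act = get_activation_fn("tanh-relu")
x = torch.tensor([-2.0, 0.0, 0.5, 10.0])
print(act(x))  # tensor([0.0000, 0.0000, 0.4621, 1.0000]) -- negatives clamp to 0, large values saturate near 1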
3 changes: 3 additions & 0 deletions sae_lens/training/config.py
@@ -38,6 +38,9 @@ class LanguageModelSAERunnerConfig:
d_sae: Optional[int] = None
b_dec_init_method: str = "geometric_median"
expansion_factor: int | list[int] = 4
activation_fn: str = "relu" # relu, tanh-relu
normalize_sae_decoder: bool = True
noise_scale: float = 0.0
from_pretrained_path: Optional[str] = None
apply_b_dec_to_input: bool = True
decoder_orthogonal_init: bool = False
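A minimal sketch of how the new options compose in the runner config. This is hypothetical usage, assuming the remaining dataclass fields keep their defaults; only the three fields added in this commit are shown explicitly.

from sae_lens.training.config import LanguageModelSAERunnerConfig

# Hypothetical configuration; other fields are left at their defaults.
cfg = LanguageModelSAERunnerConfig(
    activation_fn="tanh-relu",    # "relu" (default) or "tanh-relu"
    normalize_sae_decoder=False,  # skip the unit-norm constraint on W_dec
    noise_scale=0.1,              # std of Gaussian noise added to hidden_pre during training
)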
22 changes: 17 additions & 5 deletions sae_lens/training/sparse_autoencoder.py
@@ -6,7 +6,7 @@
import json
import os
import pickle
from typing import NamedTuple, Optional
from typing import Callable, NamedTuple, Optional

import einops
import torch
@@ -15,6 +15,7 @@
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint

from sae_lens.training.activation_functions import get_activation_fn
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.utils import BackwardsCompatiblePickleClass

@@ -35,9 +36,12 @@ class SparseAutoencoder(HookedRootModule):
lp_norm: float
d_sae: int
use_ghost_grads: bool
normalize_sae_decoder: bool
hook_point_layer: int
dtype: torch.dtype
device: str | torch.device
noise_scale: float
activation_fn: Callable[[torch.Tensor], torch.Tensor]

def __init__(
self,
@@ -69,7 +73,10 @@ def __init__(
self.dtype = cfg.dtype
self.device = cfg.device
self.use_ghost_grads = cfg.use_ghost_grads
self.normalize_sae_decoder = cfg.normalize_sae_decoder
self.hook_point_layer = cfg.hook_point_layer
self.noise_scale = cfg.noise_scale
self.activation_fn = get_activation_fn(cfg.activation_fn)

# NOTE: if using resampling neurons method, you must ensure that we initialise the weights in the order W_enc, b_enc, W_dec, b_dec
self.W_enc = nn.Parameter(
@@ -90,9 +97,10 @@ def __init__(
if self.cfg.decoder_orthogonal_init:
self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T

with torch.no_grad():
# Anthropic normalize this to have unit columns
self.set_decoder_norm_to_unit_norm()
if self.normalize_sae_decoder:
with torch.no_grad():
# Anthropic normalize this to have unit columns
self.set_decoder_norm_to_unit_norm()

self.b_dec = nn.Parameter(
torch.zeros(self.d_in, dtype=self.dtype, device=self.device)
@@ -125,7 +133,11 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
)
+ self.b_enc
)
feature_acts = self.hook_hidden_post(torch.nn.functional.relu(hidden_pre))
noisy_hidden_pre = hidden_pre
if self.noise_scale > 0:
noise = torch.randn_like(hidden_pre) * self.noise_scale
noisy_hidden_pre = hidden_pre + noise
feature_acts = self.hook_hidden_post(self.activation_fn(noisy_hidden_pre))

sae_out = self.hook_sae_out(
einops.einsum(
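The noise injection is small enough to restate standalone. A self-contained sketch of the new forward-pass behaviour (illustrative shapes and weights; the real module uses einops and hook points as shown above):

import torch

# Minimal sketch; names mirror the diff, shapes are illustrative.
d_in, d_sae, noise_scale = 2, 4, 0.1
activation_fn = torch.relu  # or the tanh-relu composition defined earlier
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.zeros(d_sae)
x = torch.randn(3, d_in)

hidden_pre = x @ W_enc + b_enc
noisy_hidden_pre = hidden_pre
if noise_scale > 0:
    # Gaussian noise is only added when noise_scale is positive,
    # so the default of 0.0 leaves the forward pass unchanged.
    noise = torch.randn_like(hidden_pre) * noise_scale
    noisy_hidden_pre = hidden_pre + noise
feature_acts = activation_fn(noisy_hidden_pre)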
9 changes: 6 additions & 3 deletions sae_lens/training/train_sae_on_language_model.py
@@ -377,7 +377,8 @@ def _train_step(

sparse_autoencoder.train()
# Make sure the W_dec is still zero-norm
sparse_autoencoder.set_decoder_norm_to_unit_norm()
if sparse_autoencoder.normalize_sae_decoder:
sparse_autoencoder.set_decoder_norm_to_unit_norm()

# log and then reset the feature sparsity every feature_sampling_window steps
if (n_training_steps + 1) % feature_sampling_window == 0:
@@ -432,7 +433,8 @@ def _train_step(

ctx.optimizer.zero_grad()
loss.backward()
sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
if sparse_autoencoder.normalize_sae_decoder:
sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
ctx.optimizer.step()
ctx.scheduler.step()

@@ -504,7 +506,8 @@ def _save_checkpoint(

ctx = train_contexts[name]
path = f"{checkpoint_path}/{name}"
sae.set_decoder_norm_to_unit_norm()
if sae.normalize_sae_decoder:
sae.set_decoder_norm_to_unit_norm()
sae.save_model(path)
log_feature_sparsities = {"sparsity": ctx.log_feature_sparsity}

6 changes: 4 additions & 2 deletions sae_lens/training/train_sae_on_toy_model.py
@@ -37,13 +37,15 @@ def train_toy_sae(
for _, batch in enumerate(pbar):
batch = next(dataloader)
# Make sure the W_dec is still zero-norm
sparse_autoencoder.set_decoder_norm_to_unit_norm()
if sparse_autoencoder.normalize_sae_decoder:
sparse_autoencoder.set_decoder_norm_to_unit_norm()

# Forward and Backward Passes
optimizer.zero_grad()
sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(batch)
loss.backward()
sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
if sparse_autoencoder.normalize_sae_decoder:
sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
optimizer.step()

n_training_tokens += batch_size
21 changes: 21 additions & 0 deletions tests/unit/training/test_activation_functions.py
@@ -0,0 +1,21 @@
import pytest
import torch

from sae_lens.training.activation_functions import get_activation_fn


def test_get_activation_fn_tanh_relu():
tanh_relu = get_activation_fn("tanh-relu")
assert tanh_relu(torch.tensor([-1.0, 0.0])).tolist() == [0.0, 0.0]
assert tanh_relu(torch.tensor(1e10)).item() == pytest.approx(1.0)


def test_get_activation_fn_relu():
relu = get_activation_fn("relu")
assert relu(torch.tensor([-1.0, 0.0])).tolist() == [0.0, 0.0]
assert relu(torch.tensor(999.9)).item() == pytest.approx(999.9)


def test_get_activation_fn_error_for_unknown_values():
with pytest.raises(ValueError):
get_activation_fn("unknown")
25 changes: 22 additions & 3 deletions tests/unit/training/test_sparse_autoencoder.py
@@ -283,13 +283,32 @@ def test_per_item_mse_loss_with_norm_matches_original_implementation() -> None:
assert torch.allclose(orig_impl_res, sae_res, atol=1e-5)


def test_SparseAutoencoder_forward_can_add_noise_to_hidden_pre() -> None:
clean_cfg = build_sae_cfg(d_in=2, d_sae=4, noise_scale=0)
noisy_cfg = build_sae_cfg(d_in=2, d_sae=4, noise_scale=100)
clean_sae = SparseAutoencoder(clean_cfg)
noisy_sae = SparseAutoencoder(noisy_cfg)

input = torch.randn(3, 2)

clean_output1 = clean_sae.forward(input).sae_out
clean_output2 = clean_sae.forward(input).sae_out
noisy_output1 = noisy_sae.forward(input).sae_out
noisy_output2 = noisy_sae.forward(input).sae_out

# with no noise, the outputs should be identical
assert torch.allclose(clean_output1, clean_output2)
# noisy outputs should be different
assert not torch.allclose(noisy_output1, noisy_output2)
assert not torch.allclose(clean_output1, noisy_output1)


def test_SparseAutoencoder_remove_gradient_parallel_to_decoder_directions() -> None:
cfg = build_sae_cfg()
cfg = build_sae_cfg(normalize_sae_decoder=True)
sae = SparseAutoencoder(cfg)
orig_grad = torch.randn_like(sae.W_dec)
orig_W_dec = sae.W_dec.clone()
sae.W_dec.grad = orig_grad.clone()

sae.remove_gradient_parallel_to_decoder_directions()

# check that the gradient is orthogonal to the decoder directions
@@ -317,7 +336,7 @@ def test_SparseAutoencoder_get_name_returns_correct_name_from_cfg_vals() -> None:


def test_SparseAutoencoder_set_decoder_norm_to_unit_norm() -> None:
cfg = build_sae_cfg()
cfg = build_sae_cfg(normalize_sae_decoder=True)
sae = SparseAutoencoder(cfg)
sae.W_dec.data = 20 * torch.randn_like(sae.W_dec)
sae.set_decoder_norm_to_unit_norm()