3 changes: 2 additions & 1 deletion overcomplete/optimization/nmf.py
@@ -9,8 +9,9 @@

from tqdm import tqdm
import torch
from scipy.optimize import nnls as scipy_nnls
from sklearn.decomposition._nmf import _initialize_nmf
from scipy.optimize import nnls as scipy_nnls


from .base import BaseOptimDictionaryLearning
from .utils import matrix_nnls, stopping_criterion, _assert_shapes
51 changes: 51 additions & 0 deletions overcomplete/sae/base.py
@@ -9,6 +9,7 @@
from ..base import BaseDictionaryLearning
from .dictionary import DictionaryLayer
from .factory import EncoderFactory
from .modules import TieableEncoder


class SAE(BaseDictionaryLearning):
@@ -172,3 +173,53 @@ def fit(self, x):
"""
raise NotImplementedError('SAE does not support fit method. You have to train the model \
using a custom training loop.')

def tied(self, bias=False):
"""
Tie the encoder to the dictionary (use D^T as the encoder weights).

Parameters
----------
bias : bool, optional
Whether to include bias in encoder, by default False.

Returns
-------
self
Returns self for method chaining.
"""
self.encoder = TieableEncoder(
in_dimensions=self.dictionary.in_dimensions,
nb_concepts=self.nb_concepts,
bias=bias,
tied_to=self.dictionary,
device=self.device
)
return self

def untied(self, bias=False, copy_from_dictionary=True):
"""
Create a new, untied encoder with weights copied from the current dictionary (or randomly initialized).

Parameters
----------
bias : bool, optional
Whether to include bias in encoder, by default False.
copy_from_dictionary : bool, optional
If True, initialize encoder with current dictionary weights, by default True.

Returns
-------
self
Returns self for method chaining.
"""
weight_init = self.get_dictionary().clone().detach() if copy_from_dictionary else None
self.encoder = TieableEncoder(
in_dimensions=self.dictionary.in_dimensions,
nb_concepts=self.nb_concepts,
bias=bias,
tied_to=None,
weight_init=weight_init,
device=self.device
)
return self
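
For context, a minimal usage sketch of the new tied/untied API (not part of the diff); it assumes the SAE(input_size, nb_concepts) constructor and the (z_pre, z, x_hat) forward signature exercised in the tests further down:

import torch
from overcomplete.sae import SAE

model = SAE(10, 5)
x = torch.randn(3, 10)

# tied: the encoder reuses D^T, so it has no weights of its own
model.tied()
z_pre, z, x_hat = model(x)

# untied: a fresh TieableEncoder initialized from the current dictionary,
# trained independently from this point on
model.untied(copy_from_dictionary=True)
z_pre, z, x_hat = model(x)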
1 change: 0 additions & 1 deletion overcomplete/sae/kernels.py
@@ -6,7 +6,6 @@
"""

import torch
import matplotlib.pyplot as plt


def rectangle_kernel(x, bandwith):
72 changes: 72 additions & 0 deletions overcomplete/sae/modules.py
@@ -2,6 +2,7 @@
Collections of torch modules for the encoding of SAE.
"""

import torch
from torch import nn
from einops import rearrange

@@ -463,3 +464,74 @@ def forward(self, x):
z = self.final_activation(pre_z)

return pre_z, z


class TieableEncoder(nn.Module):
"""
Linear encoder that can be tied to a dictionary or use independent weights.

Parameters
----------
in_dimensions : int
Input dimensionality.
nb_concepts : int
Number of latent dimensions.
bias : bool, optional
Whether to include bias, by default False.
tied_to : DictionaryLayer, optional
If provided, uses D^T (tied weights). If None, uses independent weights, by default None.
weight_init : torch.Tensor, optional
Initial weights for untied mode, by default None (uses Xavier initialization).
device : str, optional
Device for parameters, by default 'cpu'.
"""

def __init__(self, in_dimensions, nb_concepts, bias=False,
tied_to=None, weight_init=None, device='cpu'):
super().__init__()
self.tied_to = tied_to

if tied_to is None:
# untied: create own weights
self.weight = nn.Parameter(torch.empty(nb_concepts, in_dimensions, device=device))
if weight_init is not None:
self.weight.data.copy_(weight_init)
else:
nn.init.xavier_uniform_(self.weight)
else:
# tied weights: we use the dictionary transpose as encoder
# no weights needed
self.register_parameter('weight', None)

if bias:
self.bias = nn.Parameter(torch.zeros(nb_concepts, device=device))
else:
self.register_parameter('bias', None)

def forward(self, x):
"""
Encode input.

Parameters
----------
x : torch.Tensor
Input of shape (batch_size, in_dimensions).

Returns
-------
z_pre : torch.Tensor
Pre-activation codes.
z : torch.Tensor
Activated codes (ReLU applied).
"""
if self.tied_to is not None:
z_pre = x @ self.tied_to.get_dictionary().T
else:
# untied mode: use own weights
z_pre = x @ self.weight.T

if self.bias is not None:
z_pre = z_pre + self.bias

z = torch.relu(z_pre)
return z_pre, z
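
And a short standalone sketch of TieableEncoder itself (again not part of the diff), assuming DictionaryLayer(in_dimensions, nb_concepts) and its internal _weights parameter as used in the tests:

import torch
from overcomplete.sae import DictionaryLayer
from overcomplete.sae.modules import TieableEncoder

dictionary = DictionaryLayer(10, 5)
x = torch.randn(3, 10)

# tied mode: encodes with D^T and owns no weight of its own
tied = TieableEncoder(10, 5, tied_to=dictionary)
z_pre, z = tied(x)
z.sum().backward()
assert dictionary._weights.grad is not None  # gradients reach the dictionary

# untied mode: owns an independent copy of the current dictionary weights
untied = TieableEncoder(10, 5, weight_init=dictionary.get_dictionary().clone().detach())
z_pre, z = untied(x)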
4 changes: 2 additions & 2 deletions overcomplete/sae/train.py
@@ -91,8 +91,8 @@ def _log_metrics(monitoring, logs, model, z, loss, optimizer):

if monitoring > 1:
# store directly some z values
# and the params / gradients norms
logs['z'].append(z.detach()[::10])
# if needed, you can also store z statistics here,
# e.g. logs['z'].append(z.detach()[::10])
logs['z_l2'].append(l2(z).item())

logs['dictionary_sparsity'].append(l0_eps(model.get_dictionary()).mean().item())
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -18,6 +18,7 @@ opencv-python = "*"
torch = "*"
torchvision = "*"
timm = "*"
scipy = "*"

# requirements dev
[tool.poetry.group.dev.dependencies]
@@ -29,6 +30,7 @@ pylint = "*"
bumpversion = "*"
mkdocs = "*"
mkdocs-material = "*"
numkdoc = "*"

# versioning
[tool.bumpversion]
100 changes: 100 additions & 0 deletions tests/sae/test_base_sae.py
@@ -2,6 +2,9 @@

import torch
from overcomplete.sae import SAE, DictionaryLayer, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE
from overcomplete.sae.modules import TieableEncoder

from ..utils import epsilon_equal

all_sae = [SAE, JumpSAE, TopKSAE, QSAE, BatchTopKSAE, MpSAE, OMPSAE]

@@ -57,3 +60,100 @@ def test_sae_device(sae_class):
# ensure dictionary is on the meta device
dictionary = model.get_dictionary()
assert dictionary.device.type == 'meta'


def test_tieable_encoder_basic():
"""Test TieableEncoder can be created in both tied and untied modes."""
input_size = 10
nb_concepts = 5

# Create a dummy dictionary layer
dictionary = DictionaryLayer(input_size, nb_concepts)

# Test untied mode
encoder_untied = TieableEncoder(input_size, nb_concepts, tied_to=None)
assert encoder_untied.weight is not None
assert encoder_untied.tied_to is None

# Test tied mode
encoder_tied = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)
assert encoder_tied.weight is None
assert encoder_tied.tied_to is dictionary


def test_tieable_encoder_forward():
"""Test TieableEncoder forward pass in both modes."""
input_size = 10
nb_concepts = 5
batch_size = 3

dictionary = DictionaryLayer(input_size, nb_concepts)
x = torch.randn(batch_size, input_size)

# Test untied forward
encoder_untied = TieableEncoder(input_size, nb_concepts, tied_to=None)
z_pre, z = encoder_untied(x)
assert z_pre.shape == (batch_size, nb_concepts)
assert z.shape == (batch_size, nb_concepts)
assert (z >= 0).all() # ReLU activation

# Test tied forward
encoder_tied = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)
z_pre, z = encoder_tied(x)
assert z_pre.shape == (batch_size, nb_concepts)
assert z.shape == (batch_size, nb_concepts)
assert (z >= 0).all()


@pytest.mark.parametrize("sae_class", all_sae)
def test_sae_tied_untied(sae_class):
"""Test that SAE can switch between tied and untied modes."""
input_size = 10
nb_concepts = 5

model = sae_class(input_size, nb_concepts)

# Tie weights
model.tied()
assert isinstance(model.encoder, TieableEncoder)
assert model.encoder.tied_to is not None

# Untie weights
model.untied()
assert isinstance(model.encoder, TieableEncoder)
assert model.encoder.tied_to is None


@pytest.mark.parametrize("sae_class", all_sae)
def test_sae_tied_forward(sae_class):
"""Test that tied SAE produces valid outputs."""
input_size = 10
nb_concepts = 5

model = sae_class(input_size, nb_concepts)
model.tied()

x = torch.randn(3, input_size)
z_pre, z, x_hat = model(x)

assert z.shape == (3, nb_concepts)
assert x_hat.shape == (3, input_size)


@pytest.mark.parametrize("sae_class", all_sae)
def test_sae_untied_copy_weights(sae_class):
"""Test that untied with copy_from_dictionary copies weights correctly."""
input_size = 10
nb_concepts = 5

model = sae_class(input_size, nb_concepts)
model.tied()

# Get dictionary before untying
dict_before = model.get_dictionary().clone()

# Untie and copy
model.untied(copy_from_dictionary=True)

# Check that encoder weights match dictionary
assert epsilon_equal(model.encoder.weight, dict_before)
44 changes: 44 additions & 0 deletions tests/sae/test_sae_dictionary.py
@@ -2,6 +2,7 @@
import pytest

from overcomplete.sae import DictionaryLayer, SAE, QSAE, TopKSAE, JumpSAE, BatchTopKSAE, MpSAE, OMPSAE
from overcomplete.sae.modules import TieableEncoder

from ..utils import epsilon_equal

@@ -270,3 +271,46 @@ def test_multiplier_optimizer_step():

# Check that the multiplier has been updated.
assert not torch.allclose(init_multiplier, layer.multiplier.detach(), atol=1e-6)


def test_tied_encoder_shares_dictionary_weights():
"""Test that tied encoder uses dictionary weights (not a copy)."""
input_size = 10
nb_concepts = 5

dictionary = DictionaryLayer(input_size, nb_concepts)
encoder = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)

x = torch.randn(3, input_size)

# Forward pass
z_pre1, z1 = encoder(x)

# Modify dictionary weights
with torch.no_grad():
dictionary._weights.data *= 10.0
dictionary._weights.data += torch.randn_like(dictionary._weights)

# Forward pass again
z_pre2, z2 = encoder(x)

# Results should be different (weights are shared)
assert not epsilon_equal(z_pre1, z_pre2)


def test_tied_encoder_gradient_flow():
"""Test that gradients flow to dictionary through tied encoder."""
input_size = 10
nb_concepts = 5

dictionary = DictionaryLayer(input_size, nb_concepts)
encoder = TieableEncoder(input_size, nb_concepts, tied_to=dictionary)

x = torch.randn(3, input_size, requires_grad=True)
z_pre, z = encoder(x)

loss = z.sum()
loss.backward()

# Dictionary should have gradients
assert dictionary._weights.grad is not None