Add save and load methods to the model #172

Merged
merged 14 commits into from
Jan 5, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -141,6 +141,7 @@ docs/content/reference

# Wandb
wandb/
artifacts/

# Scratch files
scratch.py
2 changes: 2 additions & 0 deletions .vscode/cspell.json
@@ -9,6 +9,7 @@
"autocast",
"autoencoder",
"autoencoders",
"autoencoding",
"autofix",
"capturable",
"categoricalwprobabilities",
@@ -76,6 +77,7 @@
"optim",
"penality",
"perp",
"pickleable",
"polysemantic",
"polysemantically",
"polysemanticity",
6 changes: 3 additions & 3 deletions README.md
@@ -43,9 +43,9 @@ The library is designed to be modular. By default it takes the approach from [To
Monosemanticity: Decomposing Language Models With Dictionary Learning
](https://transformer-circuits.pub/2023/monosemantic-features/index.html), so you can pip install
the library and get started quickly. Then when you need to customise something, you can just extend
the abstract class for that component (e.g. you can extend `AbstractEncoder` if you want to
customise the encoder layer, and then easily drop it in the standard `SparseAutoencoder` model to
keep everything else as is. Every component is fully documented, so it's nice and easy to do this.
the class for that component (e.g. you can extend `SparseAutoencoder` if you want to customise the
model, and then drop it back into the training pipeline). Every component is fully documented, so
it's nice and easy to do this.
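
As a rough illustration of that workflow (a minimal sketch only: `SparseAutoencoder` and `SparseAutoencoderConfig` are taken from this PR's diff, while the subclass name, the forward-pass tweak, and the dimensions are hypothetical):

```python
from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig


class ScaledSparseAutoencoder(SparseAutoencoder):
    """Hypothetical subclass that rescales activations before the standard forward pass."""

    def forward(self, x):
        # Pre-process the input activations, then delegate to the parent implementation.
        return super().forward(x * 0.5)


# Constructed the same way as the standard model (see the demo notebook diff below).
config = SparseAutoencoderConfig(n_input_features=512, n_learned_features=2048)
model = ScaledSparseAutoencoder(config)
```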

## Demo

7 changes: 5 additions & 2 deletions docs/content/flexible_demo.ipynb
@@ -103,6 +103,7 @@
" Pipeline,\n",
" PreTokenizedDataset,\n",
" SparseAutoencoder,\n",
" SparseAutoencoderConfig,\n",
")\n",
"import wandb\n",
"\n",
@@ -235,8 +236,10 @@
"source": [
"expansion_factor = hyperparameters[\"expansion_factor\"]\n",
"autoencoder = SparseAutoencoder(\n",
" n_input_features=autoencoder_input_dim, # size of the activations we are autoencoding\n",
" n_learned_features=int(autoencoder_input_dim * expansion_factor), # size of SAE\n",
" SparseAutoencoderConfig(\n",
" n_input_features=autoencoder_input_dim, # size of the activations we are autoencoding\n",
" n_learned_features=int(autoencoder_input_dim * expansion_factor), # size of SAE\n",
" )\n",
").to(device)\n",
"autoencoder"
]
3 changes: 2 additions & 1 deletion sparse_autoencoder/__init__.py
@@ -1,7 +1,7 @@
"""Sparse Autoencoder Library."""
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
from sparse_autoencoder.loss.abstract_loss import LossReductionType
from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
@@ -77,6 +77,7 @@
"SourceModelHyperparameters",
"SourceModelRuntimeHyperparameters",
"SparseAutoencoder",
"SparseAutoencoderConfig",
"sweep",
"SweepConfig",
"TensorActivationStore",
@@ -19,7 +19,7 @@
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.loss.abstract_loss import AbstractLoss
from sparse_autoencoder.tensor_types import Axis
from sparse_autoencoder.train.utils import get_model_device
from sparse_autoencoder.train.utils.get_model_device import get_model_device


class LossInputActivationsTuple(NamedTuple):
@@ -9,7 +9,7 @@
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
from sparse_autoencoder.activation_store.base_store import ActivationStore
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig
from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
from sparse_autoencoder.loss.reducer import LossReducer
@@ -43,9 +43,11 @@ def full_activation_store() -> ActivationStore:
def autoencoder_model() -> SparseAutoencoder:
"""Create a dummy autoencoder model."""
return SparseAutoencoder(
n_components=DEFAULT_N_COMPONENTS,
n_input_features=DEFAULT_N_INPUT_FEATURES,
n_learned_features=DEFAULT_N_LEARNED_FEATURES,
SparseAutoencoderConfig(
n_input_features=DEFAULT_N_INPUT_FEATURES,
n_learned_features=DEFAULT_N_LEARNED_FEATURES,
n_components=DEFAULT_N_COMPONENTS,
)
)


@@ -126,7 +128,7 @@ def test_more_items_than_in_store_error(
):
ActivationResampler(
resample_dataset_size=DEFAULT_N_ACTIVATIONS_STORE + 1,
n_learned_features=autoencoder_model.n_learned_features,
n_learned_features=DEFAULT_N_LEARNED_FEATURES,
).compute_loss_and_get_activations(
store=full_activation_store,
autoencoder=autoencoder_model,
@@ -285,7 +287,7 @@ def test_no_changes_if_no_dead_neurons(
resample_interval=10,
n_components=DEFAULT_N_COMPONENTS,
n_activations_activity_collate=10,
n_learned_features=autoencoder_model.n_learned_features,
n_learned_features=DEFAULT_N_LEARNED_FEATURES,
resample_dataset_size=100,
)
updates = resampler.step_resampler(
@@ -328,7 +330,7 @@ def test_updates_dead_neuron_parameters(
resample_interval=10,
n_activations_activity_collate=10,
n_components=DEFAULT_N_COMPONENTS,
n_learned_features=autoencoder_model.n_learned_features,
n_learned_features=DEFAULT_N_LEARNED_FEATURES,
resample_dataset_size=100,
)
parameter_updates = resampler.step_resampler(
@@ -343,7 +345,7 @@ def test_updates_dead_neuron_parameters(
# Check the updated ones have changed
for component_idx, neuron_idx in dead_neurons:
# Decoder
decoder_weights = current_parameters["decoder._weight"]
decoder_weights = current_parameters["decoder.weight"]
current_dead_neuron_weights = decoder_weights[component_idx, neuron_idx]
updated_dead_decoder_weights = parameter_updates[
component_idx
@@ -353,7 +355,7 @@ def test_updates_dead_neuron_parameters(
), "Dead decoder weights should have changed."

# Encoder
current_dead_encoder_weights = current_parameters["encoder._weight"][
current_dead_encoder_weights = current_parameters["encoder.weight"][
component_idx, neuron_idx
]
updated_dead_encoder_weights = parameter_updates[
@@ -363,7 +365,7 @@ def test_updates_dead_neuron_parameters(
current_dead_encoder_weights, updated_dead_encoder_weights
), "Dead encoder weights should have changed."

current_dead_encoder_bias = current_parameters["encoder._bias"][
current_dead_encoder_bias = current_parameters["encoder.bias"][
component_idx, neuron_idx
]
updated_dead_encoder_bias = parameter_updates[component_idx].dead_encoder_bias_updates
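
Note: the `save`/`load` methods named in this PR's title live on the model class, whose diff is not shown in this excerpt. Purely as a hedged sketch (not necessarily the API this PR adds; the file path and sizes are made up), checkpointing with plain PyTorch `state_dict` calls would rely on the renamed parameter keys seen in the tests above:

```python
import torch

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig

config = SparseAutoencoderConfig(n_input_features=512, n_learned_features=2048)
model = SparseAutoencoder(config)

# With weight/bias registered as plain Parameters, the state dict uses keys such as
# "encoder.weight", "encoder.bias" and "decoder.weight".
torch.save(model.state_dict(), "autoencoder.pt")  # hypothetical path

# Restore into a freshly constructed model with the same config.
restored = SparseAutoencoder(config)
restored.load_state_dict(torch.load("autoencoder.pt"))
```
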
74 changes: 0 additions & 74 deletions sparse_autoencoder/autoencoder/abstract_autoencoder.py

This file was deleted.

39 changes: 12 additions & 27 deletions sparse_autoencoder/autoencoder/components/linear_encoder.py
@@ -42,33 +42,18 @@ class LinearEncoder(Module):

_n_components: int | None

_weight: Float[
weight: Float[
Parameter,
Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
]
"""Weight parameter internal state."""
"""Weight parameter.

_bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
"""Bias parameter internal state."""

@property
def weight(
self,
) -> Float[
Parameter,
Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
]:
"""Weight parameter.

Each row in the weights matrix acts as a dictionary vector, representing a single basis
element in the learned activation space.
"""
return self._weight
Each row in the weights matrix acts as a dictionary vector, representing a single basis
element in the learned activation space.
"""

@property
def bias(self) -> Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
"""Bias parameter."""
return self._bias
bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
"""Bias parameter."""

@property
def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
@@ -109,12 +94,12 @@ def __init__(
self._input_features = input_features
self._n_components = n_components

self._weight = Parameter(
self.weight = Parameter(
torch.empty(
shape_with_optional_dimensions(n_components, learnt_features, input_features),
)
)
self._bias = Parameter(
self.bias = Parameter(
torch.zeros(shape_with_optional_dimensions(n_components, learnt_features))
)
self.activation_function = ReLU()
@@ -125,12 +110,12 @@ def reset_parameters(self) -> None:
"""Initialize or reset the parameters."""
# Assumes we are using ReLU activation function (for e.g. leaky ReLU, the `a` parameter and
# `nonlinearity` must be changed).
init.kaiming_uniform_(self._weight, nonlinearity="relu")
init.kaiming_uniform_(self.weight, nonlinearity="relu")

# Bias (approach from nn.Linear)
fan_in = self._weight.size(1)
fan_in = self.weight.size(1)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self._bias, -bound, bound)
init.uniform_(self.bias, -bound, bound)

def forward(
self,
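
A brief note on the `LinearEncoder` change above: exposing `weight` and `bias` as plain `Parameter` attributes (rather than private `_weight`/`_bias` behind properties) means they appear under their public names in `state_dict()` and `named_parameters()`, matching standard `nn.Module` conventions; this is what the renamed keys in the resampler tests earlier in the diff reflect, and it simplifies saving and loading the model. A minimal check, assuming the constructor argument names shown in `__init__` above (sizes are illustrative):

```python
from sparse_autoencoder.autoencoder.components.linear_encoder import LinearEncoder

# Argument names taken from the __init__ body above; values are illustrative.
encoder = LinearEncoder(input_features=64, learnt_features=256, n_components=None)

# Public Parameter attributes surface directly in the state dict.
state = encoder.state_dict()
assert "weight" in state and "bias" in state
```
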
@@ -5,7 +5,7 @@
import torch
from torch import nn

from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.autoencoder.model import SparseAutoencoder, SparseAutoencoderConfig


class NeelAutoencoder(nn.Module):
@@ -66,8 +66,10 @@ def test_biases_initialised_same_way() -> None:

torch.random.manual_seed(0)
autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
SparseAutoencoderConfig(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)
)

torch.random.manual_seed(0)
@@ -91,8 +93,10 @@ def test_forward_pass_same_weights() -> None:
l1_coefficient: float = 0.01

autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
SparseAutoencoderConfig(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)
)
neel_autoencoder = NeelAutoencoder(
d_hidden=n_learned_features,
@@ -122,8 +126,10 @@ def test_unit_norm_weights() -> None:
l1_coefficient: float = 0.01

autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
SparseAutoencoderConfig(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)
)
neel_autoencoder = NeelAutoencoder(
d_hidden=n_learned_features,
@@ -135,7 +141,7 @@

# Set the same decoder weights
decoder_weights = torch.rand_like(autoencoder.decoder.weight)
autoencoder.decoder._weight.data = decoder_weights # noqa: SLF001 # type: ignore
autoencoder.decoder.weight.data = decoder_weights # type: ignore
neel_autoencoder.W_dec.data = decoder_weights.T

# Do a forward & backward pass so we have gradients
@@ -165,8 +171,10 @@ def test_unit_norm_weights_grad() -> None:
l1_coefficient: float = 0.01

autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
SparseAutoencoderConfig(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)
)
neel_autoencoder = NeelAutoencoder(
d_hidden=n_learned_features,
@@ -176,9 +184,9 @@

# Set the same decoder weights
decoder_weights = torch.rand_like(autoencoder.decoder.weight)
autoencoder.decoder._weight.data = decoder_weights # noqa: SLF001 # type: ignore
autoencoder.decoder.weight.data = decoder_weights # type: ignore
neel_autoencoder.W_dec.data = decoder_weights.T
autoencoder.decoder._weight.grad = torch.zeros_like(autoencoder.decoder.weight) # noqa: SLF001 # type: ignore
autoencoder.decoder.weight.grad = torch.zeros_like(autoencoder.decoder.weight) # type: ignore
neel_autoencoder.W_dec.grad = torch.zeros_like(neel_autoencoder.W_dec)

# Set the same tied bias weights