Skip to content

Commit 830adc8

Browse files
authored
Remove abstract resampler class (ai-safety-foundation#183)
1 parent fb53d5d commit 830adc8

File tree

4 files changed: 30 additions and 92 deletions.

4 files changed

+30
-92
lines changed

sparse_autoencoder/activation_resampler/abstract_activation_resampler.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
"""Activation resampler."""
2+
from dataclasses import dataclass
23
from typing import Annotated, NamedTuple
34

45
from einops import rearrange
@@ -8,10 +9,6 @@
89
from torch import Tensor
910
from torch.utils.data import DataLoader
1011

11-
from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
12-
AbstractActivationResampler,
13-
ParameterUpdateResults,
14-
)
1512
from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
1613
get_component_slice_tensor,
1714
)
@@ -23,6 +20,27 @@
2320
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes
2421

2522

23+
@dataclass
24+
class ParameterUpdateResults:
25+
"""Parameter update results from resampling dead neurons."""
26+
27+
dead_neuron_indices: Int64[Tensor, Axis.LEARNT_FEATURE_IDX]
28+
"""Dead neuron indices."""
29+
30+
dead_encoder_weight_updates: Float[
31+
Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
32+
]
33+
"""Dead encoder weight updates."""
34+
35+
dead_encoder_bias_updates: Float[Tensor, Axis.DEAD_FEATURE]
36+
"""Dead encoder bias updates."""
37+
38+
dead_decoder_weight_updates: Float[
39+
Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)
40+
]
41+
"""Dead decoder weight updates."""
42+
43+
2644
class LossInputActivationsTuple(NamedTuple):
2745
"""Loss and corresponding input activations tuple."""
2846

@@ -32,7 +50,7 @@ class LossInputActivationsTuple(NamedTuple):
3250
]
3351

3452

35-
class ActivationResampler(AbstractActivationResampler):
53+
class ActivationResampler:
3654
"""Activation resampler.
3755
3856
Collates the number of times each neuron fires over a set number of learned activation vectors,
@@ -510,9 +528,7 @@ def resample_dead_neurons(
510528

511529
def step_resampler(
512530
self,
513-
batch_neuron_activity: Int64[
514-
Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
515-
],
531+
batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
516532
activation_store: ActivationStore,
517533
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
518534
loss_fn: AbstractLoss,

sparse_autoencoder/train/pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@
1515
from transformer_lens import HookedTransformer
1616
import wandb
1717

18-
from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
19-
AbstractActivationResampler,
18+
from sparse_autoencoder.activation_resampler.activation_resampler import (
19+
ActivationResampler,
2020
ParameterUpdateResults,
2121
)
2222
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
@@ -48,7 +48,7 @@ class Pipeline:
4848
hyperparameters.
4949
"""
5050

51-
activation_resampler: AbstractActivationResampler | None
51+
activation_resampler: ActivationResampler | None
5252
"""Activation resampler to use."""
5353

5454
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder]
@@ -96,7 +96,7 @@ def n_components(self) -> int:
9696
@validate_call(config={"arbitrary_types_allowed": True})
9797
def __init__(
9898
self,
99-
activation_resampler: AbstractActivationResampler | None,
99+
activation_resampler: ActivationResampler | None,
100100
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
101101
cache_names: list[str],
102102
layer: NonNegativeInt,

sparse_autoencoder/train/tests/test_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,10 @@
1414
Pipeline,
1515
SparseAutoencoder,
1616
)
17-
from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
17+
from sparse_autoencoder.activation_resampler.activation_resampler import (
18+
ActivationResampler,
1819
ParameterUpdateResults,
1920
)
20-
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
2121
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
2222
from sparse_autoencoder.autoencoder.model import SparseAutoencoderConfig
2323
from sparse_autoencoder.metrics.abstract_metric import MetricResult

0 commit comments

Comments
 (0)