Skip to content

Commit 3a15732

Browse files
authored
Fix activation resampler normalization dimension (ai-safety-foundation#155)
1 parent 6562591 commit 3a15732

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def renormalize_and_scale(
358358
# Calculate the average norm of the encoder weights for alive neurons.
359359
detached_encoder_weight = encoder_weight.detach() # Don't track gradients
360360
alive_encoder_weights: Float[
361-
Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
361+
Tensor, Axis.names(Axis.ALIVE_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
362362
] = detached_encoder_weight[alive_neuron_mask, :]
363363
average_alive_norm: Float[Tensor, Axis.SINGLE_ITEM] = alive_encoder_weights.norm(
364364
dim=-1
@@ -416,7 +416,7 @@ def resample_dead_neurons(
416416
# vector for the dead autoencoder neuron.
417417
renormalized_input: Float[
418418
Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
419-
] = torch.nn.functional.normalize(sampled_input, dim=0)
419+
] = torch.nn.functional.normalize(sampled_input, dim=-1)
420420
dead_decoder_weight_updates = rearrange(
421421
renormalized_input, "dead_neuron input_feature -> input_feature dead_neuron"
422422
)

0 commit comments

Comments
 (0)