sparse_autoencoder/activation_resampler — 1 file changed, +2 -2 lines changed
Columns: Original file line number | Diff line number | Diff line change
@@ -358,7 +358,7 @@ def renormalize_and_scale(
358 358       # Calculate the average norm of the encoder weights for alive neurons.
359 359       detached_encoder_weight = encoder_weight.detach()  # Don't track gradients
360 360       alive_encoder_weights: Float[
361     -         Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
    361 +         Tensor, Axis.names(Axis.ALIVE_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
362 362       ] = detached_encoder_weight[alive_neuron_mask, :]
363 363       average_alive_norm: Float[Tensor, Axis.SINGLE_ITEM] = alive_encoder_weights.norm(
364 364           dim=-1
@@ -416,7 +416,7 @@ def resample_dead_neurons(
416 416       # vector for the dead autoencoder neuron.
417 417       renormalized_input: Float[
418 418           Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
419     -     ] = torch.nn.functional.normalize(sampled_input, dim=0)
    419 +     ] = torch.nn.functional.normalize(sampled_input, dim=-1)
420 420       dead_decoder_weight_updates = rearrange(
421 421           renormalized_input, "dead_neuron input_feature -> input_feature dead_neuron"
422 422       )
You can’t perform that action at this time.
0 commit comments