@fel-thomas
Collaborator

Tied/Untied Weight Mechanism for the Encoder

Weight tying keeps the encoder aligned with the decoder from the start of training, which can speed up convergence and reduces the parameter count.
The goal was to have a simple API: call .tied() or .untied() on any SAE in Overcomplete and you're done.

Implementation

Introduces TieableEncoder - a lightweight linear encoder that either:

  • Uses dictionary transpose D^T when tied
  • Maintains independent weights when untied

Works seamlessly with all SAE variants (TopK, Jump, Batch, MP, OMP, etc.)
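
To make the two modes concrete, here is a minimal PyTorch sketch of the idea. It is not the TieableEncoder code from this PR: the class name TieableLinearSketch, the bias-free design, and the assumed dictionary shape (nb_concepts, input_dim) are illustrative assumptions only.

```python
import torch
from torch import nn


class TieableLinearSketch(nn.Module):
    """Hypothetical sketch of a linear encoder that can share the dictionary weights."""

    def __init__(self, dictionary: nn.Parameter):
        super().__init__()
        # Assumption: dictionary D has shape (nb_concepts, input_dim) and is owned
        # by the decoder. Wrapping it in a list keeps it from being registered
        # (and counted) as a parameter of this encoder.
        self._shared = [dictionary]
        self.register_parameter("weight", None)
        self.untied(copy_from_dictionary=False)

    def tied(self):
        # Drop the independent weights; forward() will use D^T directly.
        self.weight = None
        return self

    def untied(self, copy_from_dictionary: bool = True):
        D = self._shared[0]
        if copy_from_dictionary:
            init = D.detach().clone()  # start from the current dictionary
        else:
            init = torch.randn_like(D) * 0.02  # arbitrary small random init
        self.weight = nn.Parameter(init)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        W = self._shared[0] if self.weight is None else self.weight
        return x @ W.T  # (batch, input_dim) -> (batch, nb_concepts)


D = nn.Parameter(torch.randn(8, 4))  # toy dictionary: 8 concepts, 4 input dims
enc = TieableLinearSketch(D)
x = torch.randn(3, 4)

enc.tied()
codes_tied = enc(x)  # computed with D^T, no encoder-owned weights

enc.untied(copy_from_dictionary=True)
codes_untied = enc(x)  # same values at the moment of switching
```

In this sketch, switching to untied with copy_from_dictionary=True leaves the encoder output unchanged at the moment of the switch, since the new weights are a copy of the dictionary; they only diverge once training continues.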

Usage

from overcomplete.sae import SAE

sae = SAE(input_shape=768, nb_concepts=2048)
sae.tied()  # encoder now uses D^T

# ... train with tied weights ...

sae.untied(copy_from_dictionary=True)  # switch to independent weights

# ... continue training ...

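When tied, the SAE has ~50% fewer parameters. When switching to untied with copy_from_dictionary=True, the independent encoder weights start from the dictionary, which gives a natural initialization.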
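
As a rough check of the ~50% figure for the sizes used in the example above, the weight matrices can be counted by hand. The helper sae_weight_count is hypothetical, and the count assumes the only large tensors are the dictionary and a linear encoder weight of the same shape, ignoring biases and any normalization layers (which is why the saving is approximate in practice).

```python
def sae_weight_count(input_dim: int, nb_concepts: int, tied: bool) -> int:
    """Count weight-matrix entries only (hypothetical helper, biases ignored)."""
    dictionary = nb_concepts * input_dim
    # A tied encoder reuses D^T, so it adds no weights of its own.
    encoder = 0 if tied else nb_concepts * input_dim
    return dictionary + encoder


untied_count = sae_weight_count(768, 2048, tied=False)  # 3,145,728
tied_count = sae_weight_count(768, 2048, tied=True)     # 1,572,864
saving = 1 - tied_count / untied_count                  # 0.5
```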
