This project provides tools to train sparse autoencoders on activations from various single-cell foundation models (scFMs), enabling mechanistic interpretability of these models. By decomposing neural network activations into sparse, interpretable features, we can better understand what biological patterns and cell states these models have learned.
- Multiple SAE Architectures: Support for various sparse autoencoder variants, including:
  - Standard SAE with L1 regularization
  - Top-K SAE
  - Batch Top-K SAE
  - Matryoshka Batch Top-K SAE
  - Gated SAE
  - JumpReLU SAE
  - P-Anneal and Gated-Anneal SAE
- Multiple scFM Adapters: Compatible with leading single-cell foundation models:
  - scGPT
  - scFoundation
  - Geneformer
- Comprehensive Analysis Tools:
  - Label scoring (cell type, batch associations)
  - Gene scoring and enrichment analysis
  - Expression pattern analysis
  - Feature density scoring
  - Feature steering and manipulation
- Efficient Training Pipeline:
  - Parallel data loading with activation buffering (see the sketch after this list)
  - Multi-GPU support
  - Hydra-based configuration management
  - Weights & Biases integration for experiment tracking
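A minimal sketch of the activation-buffering idea, for orientation only; the class name, module path, and hook logic below are assumptions, and the real implementation lives in sae4scfm/core/buffer.py:

```python
# Illustrative sketch of activation buffering; names and module paths are assumed,
# not the package's actual API.
import torch


class ActivationBuffer:
    """Collects activations from a target layer via a forward hook and
    serves shuffled mini-batches for SAE training."""

    def __init__(self, model, target_layer, capacity=262_144, batch_size=4096):
        self.acts = []                  # list of [n_tokens, d_model] tensors
        self.capacity = capacity
        self.batch_size = batch_size
        # Register a hook that stores the layer's output on every forward pass
        layer = model.encoder.layers[target_layer]   # assumed module path
        layer.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Flatten (batch, seq, d_model) -> (tokens, d_model) and move off-GPU
        self.acts.append(output.detach().flatten(0, 1).cpu())

    def fill(self, model, dataloader, device="cuda"):
        """Run the foundation model over cells until the buffer is full."""
        with torch.no_grad():
            for batch in dataloader:
                model(batch.to(device))
                if sum(a.shape[0] for a in self.acts) >= self.capacity:
                    break
        # Concatenate and shuffle so SAE batches mix cells from many forward passes
        self.data = torch.cat(self.acts)[torch.randperm(self.capacity)]
        self.acts.clear()

    def __iter__(self):
        for i in range(0, self.data.shape[0], self.batch_size):
            yield self.data[i : i + self.batch_size]
```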
Train a sparse autoencoder on a foundation model's activations:
```bash
python scripts/train_sae.py
```
The training script uses Hydra for configuration. You can override parameters from the command line:
```bash
python scripts/train_sae.py \
    scfm=scfoundation \
    data=pbmc \
    sae=batchtopk \
    sae.target_layer=5 \
    sae.dictionary_multiplier=0.66666
```
Extract SAE features for downstream analysis:
```bash
python scripts/generate_features.py \
    sae_checkpoint.experiment=layer_sweeps \
    sae_checkpoint.timestamp=Jan29-10-00
```
Run interpretability analysis on extracted features:
```bash
python scripts/analyze_features.py \
    sae_checkpoint.experiment=layer_sweeps \
    sae_checkpoint.timestamp=Jan29-10-00 \
    analysis.run_label_scoring=true \
    analysis.run_gene_scoring=true
```
Manipulate model behavior using learned SAE features:
```bash
python scripts/steer_features.py \
    sae_checkpoint.experiment=layer_sweeps \
    sae_checkpoint.timestamp=Jan29-10-00
```
The project uses Hydra for hierarchical configuration management. Configurations are organized in config/:
- config/train.yaml - Main training configuration
- config/scfm/ - Foundation model configurations
- config/sae/ - SAE architecture configurations
- config/data/ - Dataset configurations (PBMC, COVID, Census)
- config/buffer/ - Activation buffer configurations
Training:
- sae.target_layer: Which model layer to extract activations from
- sae.dictionary_multiplier: SAE hidden dimension as a multiple of the input dimension
- sae.hyperparams.k: Sparsity parameter (for Top-K variants)
- sae.hyperparams.lr: Learning rate
- seed: Random seed for reproducibility
Data:
- data.name: Dataset name (pbmc, covid, census)
- data.n_cells: Number of cells to use
- data.preprocess.split: Train/test split fraction
- data.preprocess.subset_hvg: Whether to subset to highly variable genes
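To illustrate how these keys are consumed, a Hydra entry point might look roughly like the following; the exact structure of train.yaml and of scripts/train_sae.py may differ:

```python
# Sketch of a Hydra entry point consuming the config groups above; any field
# not listed in the key-parameter tables is an assumption.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="../config", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    # Print the fully composed configuration (scfm + data + sae + buffer groups)
    print(OmegaConf.to_yaml(cfg))
    print(f"dataset={cfg.data.name}, cells={cfg.data.n_cells}")
    print(f"layer={cfg.sae.target_layer}, k={cfg.sae.hyperparams.k}, "
          f"lr={cfg.sae.hyperparams.lr}, seed={cfg.seed}")


if __name__ == "__main__":
    main()
```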
```
sae-for-scFMs/
├── config/                       # Hydra configuration files
│   ├── data/                     # Dataset configs
│   ├── sae/                      # SAE architecture configs
│   ├── scfm/                     # Foundation model configs
│   └── buffer/                   # Data buffer configs
├── scripts/                      # Entry point scripts
│   ├── train_sae.py              # Train sparse autoencoders
│   ├── generate_features.py      # Extract SAE features
│   ├── analyze_features.py       # Analyze feature interpretability
│   ├── steer_features.py         # Feature steering experiments
│   └── benchmark_integration.py
├── sae4scfm/                     # Main package
│   ├── core/                     # Core utilities
│   │   ├── buffer.py             # Activation buffering
│   │   ├── data_loader.py        # Data loading
│   │   ├── evaluation.py         # Evaluation metrics
│   │   ├── analysis.py           # Feature analysis
│   │   ├── steering.py           # Feature steering
│   │   └── io_utils.py           # I/O utilities
│   ├── sae/                      # SAE implementations
│   │   ├── standard.py           # Standard SAE
│   │   ├── top_k.py              # Top-K SAE
│   │   ├── batch_top_k.py        # Batch Top-K SAE
│   │   ├── matryoshka_batch_top_k.py
│   │   ├── gdm.py                # Gated SAE
│   │   ├── jumprelu.py           # JumpReLU SAE
│   │   └── trainer.py            # Base trainer
│   └── scfm/                     # Foundation model adapters
│       ├── base.py               # Abstract adapter interface
│       ├── scgpt/                # scGPT adapter
│       ├── scfoundation/         # scFoundation adapter
│       └── geneformer/           # Geneformer adapter
└── jobs/                         # Job submission scripts
```
Each foundation model requires a specific adapter that implements the ModelAdapter interface defined in sae4scfm/scfm/base.py. Adapters handle:
- Model loading and initialization
- Data preprocessing for the specific model format
- Forward hook registration for activation extraction
- Model-specific embedding generation
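A rough skeleton of what such an adapter could look like is shown below; the method names and signatures are illustrative assumptions, so consult sae4scfm/scfm/base.py for the actual ModelAdapter interface:

```python
# Illustrative adapter skeleton; method names and signatures are assumptions,
# not the actual ModelAdapter interface in sae4scfm/scfm/base.py.
from abc import ABC, abstractmethod

import torch


class ModelAdapter(ABC):
    """Wraps a single-cell foundation model so the SAE pipeline can
    extract activations from it in a model-agnostic way."""

    @abstractmethod
    def load_model(self, checkpoint_path: str) -> torch.nn.Module:
        """Load and initialize the foundation model from a checkpoint."""

    @abstractmethod
    def preprocess(self, adata):
        """Convert input data into the model's expected format
        (e.g. rank-value encodings for Geneformer)."""

    @abstractmethod
    def register_activation_hook(self, layer_idx: int, store: list) -> None:
        """Attach a forward hook on the given layer that appends its
        activations to `store` during the forward pass."""

    @abstractmethod
    def embed(self, batch) -> torch.Tensor:
        """Run a forward pass and return model-specific cell embeddings."""
```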
Currently supported models:
- scGPT: Generative pre-trained transformer for single-cell RNA-seq
- scFoundation: Foundation model built on the Performer architecture
- Geneformer: Transformer model trained on rank-value gene encodings
The framework supports multiple SAE architectures optimized for different use cases:
- Standard SAE: Classic autoencoder with L1 sparsity penalty
- Top-K SAE: Fixed sparsity using top-k activation selection
- Batch Top-K SAE: Batch-level top-k for improved feature diversity
- Matryoshka SAE: Nested feature learning at multiple scales
- Gated SAE: Gating mechanism for improved reconstruction
- JumpReLU SAE: JumpReLU (thresholded ReLU) activation for sharper features
See sae4scfm/sae/ for implementations.
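For intuition, a minimal Top-K SAE forward pass might look like this; it is a sketch under common conventions, not the code in sae4scfm/sae/top_k.py:

```python
# Minimal Top-K SAE sketch: encode, keep only the k largest latent activations
# per example, decode. Illustrative only; see sae4scfm/sae/ for the real code.
import torch
import torch.nn as nn


class TopKSAE(nn.Module):
    def __init__(self, d_model: int, dict_size: int, k: int):
        super().__init__()
        self.k = k
        self.encoder = nn.Linear(d_model, dict_size)
        self.decoder = nn.Linear(dict_size, d_model, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encode (centering on the decoder bias is a common convention)
        acts = torch.relu(self.encoder(x - self.b_dec))
        # Keep only the top-k activations per row; zero out the rest
        topk = torch.topk(acts, self.k, dim=-1)
        sparse = torch.zeros_like(acts).scatter_(-1, topk.indices, topk.values)
        # Reconstruct the original activation vector
        return self.decoder(sparse) + self.b_dec
```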
The framework provides comprehensive feature analysis tools:
- Label Scoring: Statistical association with cell types, batches, and other metadata
- Gene Scoring: Gene-level feature attribution and enrichment
- Expression Scoring: Relationship to gene expression patterns
- Density Scoring: Feature activation density across cells
- Gene Family Analysis: gene set enrichment analysis (GSEA)
Results are saved as structured CSV files with multi-level columns for easy downstream analysis.
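For example, a results file with multi-level columns can be loaded roughly as follows; the file name and column levels here are hypothetical:

```python
# Hypothetical example of loading an analysis output with multi-level columns;
# actual file names and column levels depend on which analyses were run.
import pandas as pd

# Two header rows -> a pandas MultiIndex over columns
scores = pd.read_csv("label_scoring.csv", header=[0, 1], index_col=0)

# Select one group of columns, then the top features for one label
cell_type_scores = scores["cell_type"]                      # assumed first level
top_features = cell_type_scores["CD8 T cell"].nlargest(10)  # assumed second level
print(top_features)
```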
Training runs are automatically tracked using:
- Weights & Biases (configurable, defaults to offline mode)
- Hydra output directories with timestamped runs
- Checkpoint saving for trained SAE models
- Metric logging for reconstruction quality and sparsity
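The tracking behavior described above can be reproduced in a standalone script roughly as follows; this is a sketch with placeholder values, since the project wires W&B up through its Hydra config rather than hard-coded calls:

```python
# Sketch of the tracking described above; project name and metric values are
# placeholders, not the project's actual settings.
import wandb

# Offline by default (per the list above); runs can be synced later with `wandb sync`.
run = wandb.init(
    project="sae4scfm",                                 # assumed project name
    mode="offline",
    config={"sae": "batchtopk", "target_layer": 5},
)

# Inside the training loop, reconstruction quality and sparsity are logged:
wandb.log({"loss/reconstruction": 0.12, "sparsity/l0": 32.0}, step=0)

run.finish()
```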