A PyTorch implementation of Sparse Selective Hyper-Connections (SHC) for stable and efficient deep residual learning.
SHC replaces traditional residual connections with sparse mixtures of orthogonal routing matrices, providing:
- Bounded spectral norm (ρ ≤ 1): Guarantees training stability (see the sketch after this list)
- 16× faster routing: Via closed-form Cayley transform (vs. Sinkhorn iteration)
- 3.3× KV cache reduction: Through learned low-rank factorization
- O(L) inference: Optional SSM distillation for linear-time generation
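The stability claim above rests on mixing orthogonal matrices with convex weights. A minimal, self-contained sketch in plain PyTorch (not the library's `CayleyTransform`/`SparseOrthogonalMixture` API; the names and sizes below are illustrative) that checks the ρ ≤ 1 bound numerically:

```python
# Illustrative only: convex mixtures of Cayley-parameterized orthogonal matrices
# have spectral norm at most 1, which is the rho <= 1 stability bound above.
import torch

def cayley(A: torch.Tensor) -> torch.Tensor:
    """Closed-form Cayley transform: skew-symmetric A -> orthogonal Q = (I - A)(I + A)^-1."""
    I = torch.eye(A.shape[-1], dtype=A.dtype)
    return (I - A) @ torch.linalg.inv(I + A)

d, n = 64, 4
Qs = []
for _ in range(n):
    W = torch.randn(d, d)
    Qs.append(cayley(W - W.T))                   # W - W^T is skew-symmetric, so Q is orthogonal

alpha = torch.softmax(torch.randn(n), dim=0)     # convex mixing weights (sum to 1)
H = sum(a * Q for a, Q in zip(alpha, Qs))        # mixture of orthogonal matrices

# Each Q has spectral norm exactly 1; the convex mixture is bounded by 1.
print(torch.linalg.matrix_norm(H, ord=2).item())  # <= 1.0 up to numerical error
```

In the library, `SparseOrthogonalMixture` makes the mixing weights input-dependent and sparse (Eq. 7 / Prop. 1 in the component table below).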
# Clone repository
git clone https://github.com/rahvis/shc.git
cd shc
# Create virtual environment
python -m venv venv
source venv/bin/activate # Linux/Mac
# or: venv\Scripts\activate # Windows
# Install dependencies
pip install -r shc/requirements.txt
# Install package in development mode
pip install -e .

from shc.models import SHCTransformer, get_config
# Create model with predefined config
config = get_config('500m') # Options: '500m', '1b', '3b', '7b'
model = SHCTransformer(config)
# Forward pass
import torch
input_ids = torch.randint(0, 32000, (2, 512))
logits = model(input_ids)
# Generate text
output = model.generate(
    input_ids[:, :10],   # prompt
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
)

# Single GPU
python -m shc.scripts.train --model_size 500m --output_dir ./output
# Multi-GPU with DDP
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 3b \
    --batch_size 32 \
    --learning_rate 3e-4

# FSDP for 7B+ models (memory efficient)
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 7b \
    --use_fsdp \
    --mixed_precision bf16

# Run benchmarks
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --benchmarks bbh gsm8k mmlu

# Efficiency profiling
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --profile \
    --analyze_routing

from shc.models import SHCTransformer, SSMStudent
from shc.training import DistillationTrainer, DistillationConfig
# Load trained teacher
teacher = SHCTransformer.from_pretrained('path/to/teacher')
# Create student matching teacher dimensions
student = SSMStudent.from_teacher_config(teacher.config)
# Distill
config = DistillationConfig(max_steps=10000)
trainer = DistillationTrainer(teacher, student, config, train_loader)
trainer.train()
# Student generates with O(1) per step (no KV cache!)
output = student.generate(input_ids, max_new_tokens=100)
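The comment above is the point of distillation: the student replaces attention's growing KV cache with a fixed-size recurrent state. A minimal, hypothetical sketch of such a state-space recurrence (plain PyTorch; not the repository's `SSMStudent`, and all sizes are illustrative):

```python
# Illustrative only: a diagonal linear state-space recurrence keeps a fixed-size
# state h, so each generation step costs O(1) time and memory instead of growing
# a KV cache with the sequence length.
import torch

d_state, d_model, seq_len = 16, 8, 5

A = torch.rand(d_state) * 0.9        # per-channel state decay (|A| < 1), learned in practice
B = torch.randn(d_state, d_model)    # input projection
C = torch.randn(d_model, d_state)    # output projection

h = torch.zeros(d_state)             # the only state carried between steps
xs = torch.randn(seq_len, d_model)   # stand-in for token embeddings

for t in range(seq_len):
    h = A * h + B @ xs[t]            # constant work per step, no history stored
    y = C @ h                        # output used for the next-token prediction
```

In the repository, `DistillationTrainer` transfers the trained SHC teacher into such a student, so generation runs in O(L) total time.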

| Component | Description | Reference |
|---|---|---|
| CayleyTransform | Closed-form orthogonal matrix: Q = (I-A)(I+A)⁻¹ | Eq. 9 |
| SparseOrthogonalMixture | H^res = Σ αᵢ(x)·Qᵢ with ρ ≤ 1 | Eq. 7, Prop. 1 |
| FactorizedKVCache | Low-rank compression: x̄ ≈ UV^T (sketch below) | Eq. 14 |
| AdaptiveRankSelector | Gumbel-Softmax rank selection | Eq. 16 |
| SHCBlock | Complete block with triple routing | Algorithm 1 |
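To make the FactorizedKVCache row concrete: a hypothetical sketch of the rank-r factorization x̄ ≈ UV^T, using a truncated SVD in place of the learned factors (the shapes and rank below are illustrative, and this is not the library's implementation):

```python
# Illustrative only: approximate a stack of n residual streams x_bar (n x d)
# with a rank-r factorization U @ V.T and compare storage cost.
import torch

n, d, r = 4, 768, 1
x_bar = torch.randn(n, d)

# Best rank-r approximation via truncated SVD; the library learns U and V instead.
U_full, S, Vh = torch.linalg.svd(x_bar, full_matrices=False)
U = U_full[:, :r] * S[:r]            # (n, r)
V = Vh[:r, :].T                      # (d, r)
x_hat = U @ V.T                      # rank-r reconstruction, shape (n, d)

dense_floats = n * d                 # 3072 values stored densely
lowrank_floats = r * (n + d)         # 772 values at r = 1
rel_err = (torch.linalg.norm(x_bar - x_hat) / torch.linalg.norm(x_bar)).item()
print(dense_floats, lowrank_floats, rel_err)

# On random data the rank-1 error is large; the premise of Eq. 14 is that trained
# residual streams are highly correlated, so a very low rank reconstructs them well.
```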
| Size | Hidden | Layers | Heads | Parameters |
|---|---|---|---|---|
| 500M | 1024 | 24 | 16 | ~500M |
| 1B | 2048 | 24 | 16 | ~1B |
| 3B | 2560 | 32 | 32 | ~3B |
| 7B | 4096 | 32 | 32 | ~7B |
shc/
├── __init__.py               # Package init
├── requirements.txt          # Dependencies
├── configs/
│   └── config.py             # Model/training configs
├── layers/
│   ├── cayley.py             # Cayley transform
│   ├── sparse_mixture.py     # Sparse orthogonal routing
│   ├── factorized_cache.py   # KV cache compression
│   └── adaptive_rank.py      # Rank selection
├── blocks/
│   ├── attention.py          # Multi-head attention + RoPE
│   ├── feedforward.py        # SwiGLU FFN
│   └── shc_block.py          # Complete SHC block
├── models/
│   ├── embeddings.py         # Token/positional embeddings
│   ├── transformer.py        # SHCTransformer
│   └── ssm_student.py        # SSM for O(L) inference
├── training/
│   ├── distributed.py        # DDP/FSDP utilities
│   ├── optimizer.py          # Adam + cosine scheduler
│   ├── trainer.py            # Training loop
│   └── distillation.py       # Teacher→student distillation
├── data/
│   ├── dataset.py            # Dataset classes
│   └── dataloader.py         # Distributed data loading
├── evaluation/
│   ├── metrics.py            # PPL, accuracy, F1, BLEU
│   ├── benchmarks.py         # BBH, GSM8K, MMLU
│   └── profiler.py           # Efficiency profiling
└── scripts/
    ├── train.py              # Training CLI
    └── evaluate.py           # Evaluation CLI
# Spectral norm is bounded by construction
from shc.layers import SparseOrthogonalMixture
import torch

x = torch.randn(2, 128, 768)   # example hidden states; shape chosen for illustration
routing = SparseOrthogonalMixture(n=4, k=2, hidden_dim=768)
H_res = routing(x)
spectral_norm = routing.get_spectral_norm(x)  # Always ≤ 1.0

from shc.training import setup_distributed, wrap_model_ddp, wrap_model_fsdp
# Automatic distributed setup
rank, local_rank, world_size = setup_distributed()
# DDP for data parallelism
model = wrap_model_ddp(model)
# Or FSDP for memory efficiency
model = wrap_model_fsdp(model, mixed_precision=True)

from shc.layers import FactorizedKVCache
import torch

x_bar = torch.randn(4, 768)   # example stacked residual streams; shape chosen for illustration
cache = FactorizedKVCache(n=4, d=768, r=1)
compressed = cache.compress(x_bar)            # 4×768 → 1 scalar
reconstructed = cache.decompress(compressed)  # 99% accurate

from shc.evaluation import RoutingAnalyzer
analyzer = RoutingAnalyzer(model)
analyzer.analyze_batch(input_ids)
stats = analyzer.get_summary()
# {'spectral_norms': {'mean': 0.98, 'max': 1.0}, 'mixing_entropy': {...}}

Target performance (from paper):
| Benchmark | SHC | MHC | DenseRes |
|---|---|---|---|
| BBH (23 tasks) | 42.3% | 42.1% | 40.8% |
| GSM8K | 28.7% | 28.5% | 27.2% |
| MMLU (5-shot) | 45.2% | 45.0% | 44.1% |
Efficiency gains:
- 16× speedup in routing computation (closed-form Cayley vs. Sinkhorn; see the sketch after this list)
- 3.3× reduction in KV cache memory
- <1% overhead vs baseline Transformer
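The routing speedup comes from replacing iterative Sinkhorn normalization with a single closed-form Cayley solve. A rough, hypothetical micro-benchmark sketch of that comparison (this is not the paper's measurement setup; the matrix size, step count, and any observed ratio are illustrative and hardware-dependent):

```python
# Illustrative timing comparison: one closed-form Cayley solve vs. iterative
# Sinkhorn normalization. Sizes and step counts are arbitrary; results vary.
import time
import torch

d, repeats, sinkhorn_steps = 256, 100, 20
W = torch.randn(d, d)
A = W - W.T                             # skew-symmetric parameter
I = torch.eye(d)

def cayley(A):
    # Q = (I - A)(I + A)^-1; the two factors commute, so one linear solve suffices.
    return torch.linalg.solve(I + A, I - A)

def sinkhorn(logits, steps):
    # Repeated row/column normalization toward a doubly stochastic matrix.
    P = logits.softmax(dim=-1)
    for _ in range(steps):
        P = P / P.sum(dim=0, keepdim=True)
        P = P / P.sum(dim=1, keepdim=True)
    return P

t0 = time.perf_counter()
for _ in range(repeats):
    cayley(A)
t1 = time.perf_counter()
for _ in range(repeats):
    sinkhorn(W, sinkhorn_steps)
t2 = time.perf_counter()
print(f"Cayley: {t1 - t0:.4f}s   Sinkhorn ({sinkhorn_steps} steps): {t2 - t1:.4f}s")
```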
- Python 3.9+
- PyTorch 2.0+
- CUDA 11.8+ (for GPU training)
See shc/requirements.txt for full dependencies.
@article{shc2026,
  title={Sparse Selective Hyper-Connections: A Unified Framework for
         Stable and Efficient Deep Residual Learning},
  author={...},
  journal={...},
  year={2026}
}

MIT License - see LICENSE for details.