
SHC: Sparse Selective Hyper-Connections


A PyTorch implementation of Sparse Selective Hyper-Connections for stable and efficient deep residual learning.

Overview

SHC replaces traditional residual connections with sparse mixtures of orthogonal routing matrices, providing:

  • Bounded spectral norm (ρ ≤ 1): Guarantees training stability (see the sketch after this list)
  • 16× faster routing: Via closed-form Cayley transform (vs. Sinkhorn iteration)
  • 3.3× KV cache reduction: Through learned low-rank factorization
  • O(L) inference: Optional SSM distillation for linear-time generation
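
The spectral-norm bound follows from the triangle inequality: a convex combination of orthogonal matrices is non-expansive. A minimal self-contained check in plain PyTorch (illustrative only, not the library's SparseOrthogonalMixture):

import torch

# Two random orthogonal matrices via QR, plus convex mixing weights.
Q1, _ = torch.linalg.qr(torch.randn(64, 64))
Q2, _ = torch.linalg.qr(torch.randn(64, 64))
alpha = torch.softmax(torch.randn(2), dim=0)   # alpha_i >= 0, sum(alpha) = 1

# ||sum_i alpha_i * Q_i||_2 <= sum_i alpha_i * ||Q_i||_2 = 1
H = alpha[0] * Q1 + alpha[1] * Q2
print(torch.linalg.matrix_norm(H, ord=2).item())  # <= 1.0 up to float error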

Installation

# Clone repository
git clone https://github.com/rahvis/shc.git
cd shc

# Create virtual environment
python -m venv venv
source venv/bin/activate  # Linux/Mac
# or: venv\Scripts\activate  # Windows

# Install dependencies
pip install -r shc/requirements.txt

# Install package in development mode
pip install -e .
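
To sanity-check the install, this minimal Python snippet should run without errors (it uses only the names documented in Quick Start below):

from shc.models import get_config

# Should print the 500M model configuration if the editable install succeeded
print(get_config('500m'))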

Quick Start

Basic Usage

import torch

from shc.models import SHCTransformer, get_config

# Create model with a predefined config
config = get_config('500m')  # Options: '500m', '1b', '3b', '7b'
model = SHCTransformer(config)

# Forward pass on a random token batch (batch_size=2, seq_len=512)
input_ids = torch.randint(0, 32000, (2, 512))
logits = model(input_ids)

# Generate text
output = model.generate(
    input_ids[:, :10],  # prompt
    max_new_tokens=100,
    temperature=0.7,
    top_p=0.9,
)
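
Assuming a standard language-modeling head, logits has shape (batch, seq_len, vocab_size), and generate returns the prompt extended by up to max_new_tokens sampled tokens.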

Training

# Single GPU
python -m shc.scripts.train --model_size 500m --output_dir ./output

# Multi-GPU with DDP
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 3b \
    --batch_size 32 \
    --learning_rate 3e-4

# FSDP for 7B+ models (memory efficient)
torchrun --nproc_per_node=8 -m shc.scripts.train \
    --model_size 7b \
    --use_fsdp \
    --mixed_precision bf16

Evaluation

# Run benchmarks
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --benchmarks bbh gsm8k mmlu

# Efficiency profiling
python -m shc.scripts.evaluate \
    --model_path ./output/final \
    --profile \
    --analyze_routing
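
The PPL metric in shc/evaluation/metrics.py is standard perplexity; as a conceptual reference, a minimal version in plain PyTorch (an illustrative sketch, not the library's implementation):

import torch
import torch.nn.functional as F

def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """exp(mean token-level cross-entropy).

    logits: (batch, seq_len, vocab), targets: (batch, seq_len)
    """
    ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return torch.exp(ce)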

SSM Distillation

from shc.models import SHCTransformer, SSMStudent
from shc.training import DistillationTrainer, DistillationConfig

# Load trained teacher
teacher = SHCTransformer.from_pretrained('path/to/teacher')

# Create student matching teacher dimensions
student = SSMStudent.from_teacher_config(teacher.config)

# Distill
config = DistillationConfig(max_steps=10000)
trainer = DistillationTrainer(teacher, student, config, train_loader)  # train_loader: your tokenized DataLoader
trainer.train()

# Student generates with O(1) per step (no KV cache!)
output = student.generate(input_ids, max_new_tokens=100)
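
The actual objective lives in shc/training/distillation.py; as a conceptual sketch (an assumption about the general recipe, not the library's exact loss), standard logit distillation matches temperature-softened teacher and student distributions with a KL term:

import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, T: float = 2.0):
    # Soften both distributions with temperature T, then match with KL divergence.
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T ** 2)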

Architecture

Core Components

| Component | Description | Reference |
| --- | --- | --- |
| CayleyTransform | Closed-form orthogonal matrix: Q = (I-A)(I+A)⁻¹ | Eq. 9 |
| SparseOrthogonalMixture | H^res = Σ αᵢ(x)·Qᵢ with ρ ≤ 1 | Eq. 7, Prop. 1 |
| FactorizedKVCache | Low-rank compression: x̄ ≈ UV^T | Eq. 14 |
| AdaptiveRankSelector | Gumbel-Softmax rank selection | Eq. 16 |
| SHCBlock | Complete block with triple routing | Algorithm 1 |
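
The Cayley transform maps any skew-symmetric A to an orthogonal Q with no iterative procedure, which is the source of the routing speedup over Sinkhorn. A minimal standalone check (plain PyTorch, not the library's CayleyTransform):

import torch

# Build a skew-symmetric matrix: A = M - M^T satisfies A^T = -A.
M = torch.randn(64, 64)
A = M - M.T

# Cayley transform: Q = (I - A)(I + A)^{-1}.
# I + A is always invertible when A is skew-symmetric.
I = torch.eye(64)
Q = (I - A) @ torch.linalg.inv(I + A)

# Q is orthogonal: Q Q^T = I (up to float error).
print(torch.allclose(Q @ Q.T, I, atol=1e-4))  # True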

Model Configurations

| Size | Hidden | Layers | Heads | Parameters |
| --- | --- | --- | --- | --- |
| 500M | 1024 | 24 | 16 | ~500M |
| 1B | 2048 | 24 | 16 | ~1B |
| 3B | 2560 | 32 | 32 | ~3B |
| 7B | 4096 | 32 | 32 | ~7B |

Project Structure

shc/
├── __init__.py              # Package init
├── requirements.txt         # Dependencies
├── configs/
│   └── config.py           # Model/training configs
├── layers/
│   ├── cayley.py           # Cayley transform
│   ├── sparse_mixture.py   # Sparse orthogonal routing
│   ├── factorized_cache.py # KV cache compression
│   └── adaptive_rank.py    # Rank selection
├── blocks/
│   ├── attention.py        # Multi-head attention + RoPE
│   ├── feedforward.py      # SwiGLU FFN
│   └── shc_block.py        # Complete SHC block
├── models/
│   ├── embeddings.py       # Token/positional embeddings
│   ├── transformer.py      # SHCTransformer
│   └── ssm_student.py      # SSM for O(L) inference
├── training/
│   ├── distributed.py      # DDP/FSDP utilities
│   ├── optimizer.py        # Adam + cosine scheduler
│   ├── trainer.py          # Training loop
│   └── distillation.py     # Teacher→student distillation
├── data/
│   ├── dataset.py          # Dataset classes
│   └── dataloader.py       # Distributed data loading
├── evaluation/
│   ├── metrics.py          # PPL, accuracy, F1, BLEU
│   ├── benchmarks.py       # BBH, GSM8K, MMLU
│   └── profiler.py         # Efficiency profiling
└── scripts/
    ├── train.py            # Training CLI
    └── evaluate.py         # Evaluation CLI

Key Features

1. Stable Training via Orthogonal Routing

import torch
from shc.layers import SparseOrthogonalMixture

# Spectral norm is bounded by construction
routing = SparseOrthogonalMixture(n=4, k=2, hidden_dim=768)
x = torch.randn(2, 128, 768)  # example activations (shape assumed)
H_res = routing(x)
spectral_norm = routing.get_spectral_norm(x)  # Always ≤ 1.0

2. Efficient Multi-GPU Training

from shc.training import setup_distributed, wrap_model_ddp, wrap_model_fsdp

# Automatic distributed setup
rank, local_rank, world_size = setup_distributed()

# DDP for data parallelism
model = wrap_model_ddp(model)

# Or FSDP for memory efficiency
model = wrap_model_fsdp(model, mixed_precision=True)

3. Memory-Efficient KV Cache

import torch
from shc.layers import FactorizedKVCache

cache = FactorizedKVCache(n=4, d=768, r=1)
x_bar = torch.randn(4, 768)  # example stacked stream states (shape assumed)
compressed = cache.compress(x_bar)            # 4×768 → 1 scalar (rank r=1)
reconstructed = cache.decompress(compressed)  # ~99% accurate
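
The memory saving comes from keeping rank-r factors instead of the full n×d matrix. An illustrative sketch of the low-rank idea via truncated SVD (plain PyTorch; the library learns its factors during training rather than computing an SVD):

import torch

n, d, r = 4, 768, 1
x_bar = torch.randn(n, d)

# Best rank-r approximation: keep the top-r singular triplets.
U, S, Vh = torch.linalg.svd(x_bar, full_matrices=False)
U_r = U[:, :r] * S[:r]   # (n, r)
V_r = Vh[:r, :]          # (r, d)

# Storage drops from n*d values to r*(n + d).
approx = U_r @ V_r         # x_bar ≈ UV^T (Eq. 14)
print(n * d, r * (n + d))  # 3072 vs 772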

4. Routing Analysis

from shc.evaluation import RoutingAnalyzer

analyzer = RoutingAnalyzer(model)
analyzer.analyze_batch(input_ids)
stats = analyzer.get_summary()
# {'spectral_norms': {'mean': 0.98, 'max': 1.0}, 'mixing_entropy': {...}}
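
Here, mixing_entropy presumably summarizes how uniformly the learned weights αᵢ(x) spread over the n orthogonal matrices: low entropy means routing has collapsed onto a few Qᵢ, high entropy means near-uniform mixing.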

Benchmarks

Target performance (from paper):

| Benchmark | SHC | MHC | DenseRes |
| --- | --- | --- | --- |
| BBH (23 tasks) | 42.3% | 42.1% | 40.8% |
| GSM8K | 28.7% | 28.5% | 27.2% |
| MMLU (5-shot) | 45.2% | 45.0% | 44.1% |

Efficiency gains:

  • 16× speedup in routing computation
  • 3.3× reduction in KV cache memory
  • <1% overhead vs baseline Transformer

Requirements

  • Python 3.9+
  • PyTorch 2.0+
  • CUDA 11.8+ (for GPU training)

See shc/requirements.txt for full dependencies.

Citation

@article{shc2026,
  title={Sparse Selective Hyper-Connections: A Unified Framework for 
         Stable and Efficient Deep Residual Learning},
  author={...},
  journal={...},
  year={2026}
}

License

MIT License - see LICENSE for details.
