High-performance, extensible, chainable optimizers for PyTorch.
- Lightning-Fast Training: Batched foreach operations deliver significant speedups on large models.
- Adaptive & Extensible: Built-in AdamW, RMSprop, Schedule-Free algorithms, and PaLM-inspired schedules.
- Plug-and-Play: Drop-in replacements for torch.optim with seamless integration.
- Customizable: Chainable API lets you compose optimizers and transforms (MARS correction, cautious updates, orthogonal updates).
- Battle-Tested: Extensive benchmarks and real-world examples included.
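To make "batched foreach operations" concrete, here is a minimal sketch in plain PyTorch. It is not HeavyBall's implementation: the tensors and learning rate are made up for illustration, and `torch._foreach_add_` is PyTorch's own multi-tensor primitive, used here only to show the idea of updating a whole list of tensors in one call instead of looping in Python.

```python
import torch

# Hypothetical parameters and gradients standing in for a model's tensors.
params = [torch.ones(3), torch.ones(2, 2)]
grads = [torch.full((3,), 0.5), torch.full((2, 2), 2.0)]
lr = 0.1

# Per-tensor loop: one separate operation per parameter.
looped = [p - lr * g for p, g in zip(params, grads)]

# Batched variant: a single call applies p += (-lr) * g to every tensor in place.
batched = [p.clone() for p in params]
torch._foreach_add_(batched, grads, alpha=-lr)

# Both routes produce the same values.
for a, b in zip(looped, batched):
    assert torch.allclose(a, b)
```

The saving comes from dispatching one call for many tensors, which matters most when a model has a large number of small parameters.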
- Foreach-based optimizers: ForeachAdamW, ForeachRMSprop, ForeachSFAdamW, Muon, ADOPT, MSAM, …
- Schedule-Free optimizers with dynamic learning rate adaptation.
- Advanced update rules: MARS correction, cautious updates, PaLM beta2 scheduling.
- Chainable transforms for custom optimization recipes.
- Comprehensive benchmark suite (benchmark/).
- Detailed documentation and example-driven tutorials.
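Two of the update rules listed above, and the idea of chaining transforms, can be sketched in a few lines of plain PyTorch. Everything here is illustrative rather than HeavyBall's API: `palm_beta2`, `cautious`, `sign_update`, and `chain` are hypothetical names, the 0.999 cap and the `eps` floor are assumptions, and the formulas follow the commonly cited forms (PaLM's `beta2 = 1 - step^(-0.8)`, and a cautious mask that keeps only update components agreeing in sign with the gradient).

```python
import torch


def palm_beta2(step: int, power: float = 0.8, max_beta2: float = 0.999) -> float:
    """PaLM-style schedule: beta2 = 1 - step^(-power); the cap is an assumption."""
    return min(1.0 - step ** -power, max_beta2)


def cautious(update: torch.Tensor, grad: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """Zero components whose sign disagrees with the gradient, then rescale."""
    mask = (update * grad > 0).to(update.dtype)
    return update * mask / mask.mean().clamp(min=eps)


def sign_update(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Replace each component of the update by its sign (grad is unused)."""
    return torch.sign(update)


def chain(*transforms):
    """Compose transforms left to right; each maps (update, grad) -> update."""
    def apply(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        for transform in transforms:
            update = transform(update, grad)
        return update
    return apply


update = torch.tensor([1.0, -1.0, 2.0, 3.0])
grad = torch.tensor([1.0, 1.0, -1.0, 2.0])

masked = cautious(update, grad)                   # components 1 and 2 disagree with grad
recipe = chain(cautious, sign_update)(update, grad)
```

Writing each rule as a function with the same `(update, grad)` signature is what makes the composition trivial: a recipe is just an ordered list of such functions.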
Install:
pip install heavyball
Basic usage:
import torch
from torch import nn
from heavyball import ForeachAdamW

model = nn.Sequential(
    nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10)
)
optimizer = ForeachAdamW(model.parameters(), lr=1e-3)

# dataloader: any iterable yielding (data, target) batches,
# e.g. a torch.utils.data.DataLoader
for data, target in dataloader:
    optimizer.zero_grad()
    output = model(data)
    loss = torch.nn.functional.cross_entropy(output, target)
    loss.backward()
    optimizer.step()
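The snippet above leaves `dataloader` undefined. A self-contained variant on synthetic data might look like the following sketch. The random dataset, batch size, epoch count, and the `full_batch_loss` helper are assumptions made for illustration. The fallback to `torch.optim.AdamW` is there only so the script still runs without HeavyBall installed; it leans on the drop-in claim above and is not something the project prescribes.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

try:
    from heavyball import ForeachAdamW as Optimizer
except ImportError:
    # Assumption: fall back to the stock optimizer if heavyball is missing.
    from torch.optim import AdamW as Optimizer

torch.manual_seed(0)

# Synthetic classification data: 256 samples, 128 features, 10 classes.
inputs = torch.randn(256, 128)
targets = torch.randint(0, 10, (256,))
dataloader = DataLoader(TensorDataset(inputs, targets), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = Optimizer(model.parameters(), lr=1e-3)


def full_batch_loss() -> float:
    """Cross-entropy over the whole synthetic dataset, without gradients."""
    with torch.no_grad():
        return nn.functional.cross_entropy(model(inputs), targets).item()


loss_before = full_batch_loss()
for epoch in range(5):
    for data, target in dataloader:
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()
loss_after = full_batch_loss()

print(f"loss: {loss_before:.3f} -> {loss_after:.3f}")
```

Because the labels are random, the loss only measures how well the model memorizes the training set, which is enough to confirm that the optimizer is stepping.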
Reproduce benchmarks with:
python3 -m benchmark.run_all_benchmarks --opt ForeachSOAP --opt LaProp --opt AdamW --opt Muon --opt ForeachCachedNewtonPSGD --opt RMSprop --opt OrthoLaProp --opt ForeachSFAdamW --opt ForeachADOPT --opt LaPropOrtho --opt CachedPSGDKron --opt SignLaProp --opt ForeachSOLP --opt PSGDLRA --opt NewtonPSGDLRA --opt NewtonHybrid2PSGDKron --opt NewtonHybrid2PSGDLRA --opt mars-NewtonHybrid2PSGDLRA --opt MSAMLaProp --opt mars-adaptive-NewtonHybrid2PSGDKron --opt mars-ortho-NewtonHybrid2PSGDKron --opt MuonLaProp --opt mars-unscaled-NewtonHybrid2PSGDKron --opt mars-NewtonHybrid2PSGDKron --opt cautious-AdamW --opt unscaled_cautious-AdamW --opt mars-AdamW --dtype float32 --steps 1000000 --trials 1000 --parallelism 256 --seeds 1 --difficulties trivial --difficulties easy --difficulties medium --difficulties hard --difficulties extreme --difficulties nightmare --timeout 2880
We welcome contributions! Please check the issue tracker and follow these steps:
- Fork the repo and create a feature branch.
- Install dev dependencies: pip install -e .[dev]
- Run tests: pytest
- Submit a pull request.
BSD 3-Clause — see the LICENSE file.
Made by the HeavyBall team.