HesScale is built on top of PyTorch and BackPACK. It allows Hessian diagonals to be backpropagated through the layers of the network alongside the gradients.
To install HesScale, create a virtual environment and install the package from the root of the repository:
python3.7 -m venv .hesscale
source .hesscale/bin/activate
python -m pip install --upgrade pip
pip install .
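As a quick sanity check of the environment, you can verify that the core dependencies are importable (a sketch; it only checks PyTorch and BackPACK, not the HesScale modules themselves):
python -c "import torch, backpack; print(torch.__version__)"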
We provide a few minimal examples in the examples directory to show how to use this package. Here is one of them:
import torch
from backpack import backpack, extend
from optimizers.adahesscale import AdaHesScale

# Hyperparameters
hidden_units = 128
n_obs = 6
n_classes = 10
lr = 0.0004
batch_size = 1

# A small feed-forward network
model = torch.nn.Sequential(
    torch.nn.Linear(n_obs, hidden_units),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_units, hidden_units),
    torch.nn.Tanh(),
    torch.nn.Linear(hidden_units, n_classes),
)
loss_func = torch.nn.CrossEntropyLoss()

optimizer = AdaHesScale(model.parameters(), lr=lr)

# Extend the model and the loss so that the extra (HesScale) quantities
# can be propagated during the backward pass
extend(model)
extend(loss_func)

# A random batch of inputs and targets
inputs = torch.randn((batch_size, n_obs))
target_class = torch.randint(0, n_classes, (batch_size,))

prediction = model(inputs)
optimizer.zero_grad()
loss = loss_func(prediction, target_class)

# Compute gradients and HesScale quantities, then take an optimization step
with backpack(optimizer.method):
    loss.backward()
optimizer.step()
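For context, the sketch below shows how the same pieces fit into a small training loop; it simply repeats the forward/backward/step pattern above on freshly sampled random batches (a real application would replace these with a data loader):

for step in range(100):
    # Sample a random batch (replace with a real data loader)
    inputs = torch.randn((batch_size, n_obs))
    target_class = torch.randint(0, n_classes, (batch_size,))

    optimizer.zero_grad()
    prediction = model(inputs)
    loss = loss_func(prediction, target_class)

    # The backward pass inside the backpack context also computes the
    # quantities that the optimizer uses in its update
    with backpack(optimizer.method):
        loss.backward()
    optimizer.step()

    if step % 10 == 0:
        print(f"step {step:3d}  loss {loss.item():.4f}")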
We would appreciate any help in extending HesScale to recurrent neural networks. If you would like to contribute, please fork the repo and create a pull request.
Distributed under the MIT License. See LICENSE
for more information.
To reproduce the experiments in the paper, run the scripts in the experiments
directory for the approximation-quality and computational-cost experiments. To reproduce the training plots, please refer to our other repo, which uses DeepOBS for complete reproduction of our optimization results.
If you use our code, please consider citing our paper:
Elsayed, M., & Mahmood, A. R. (2022). HesScale: Scalable Computation of Hessian Diagonals. arXiv preprint arXiv:2210.11639.
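For convenience, a BibTeX entry matching the reference above (fields taken from the citation; please verify against the arXiv page):

@article{elsayed2022hesscale,
  title   = {HesScale: Scalable Computation of Hessian Diagonals},
  author  = {Elsayed, M. and Mahmood, A. R.},
  journal = {arXiv preprint arXiv:2210.11639},
  year    = {2022}
}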