This repository contains the PyTorch source code for the arXiv paper SASSHA: Sharpness-aware Adaptive Second-order Optimization With Stable Hessian Approximation by Dahun Shin*, Dongyeop Lee*, Jinseok Chung, and Namhoon Lee.
SASSHA is a novel second-order method designed to enhance generalization by explicitly reducing the sharpness of the solution, while stabilizing the computation of approximate Hessians along the optimization trajectory.
This PyTorch implementation supports various tasks, including image classification, fine-tuning, and label-noise experiments.
For a detailed explanation of the SASSHA algorithm, please refer to our paper.
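At a high level, each iteration first perturbs the weights toward higher loss (the sharpness-aware step) and then takes a step preconditioned by Hessian information computed at the perturbed point. The following is only our rough schematic; the exact preconditioner and its stabilization are defined in the paper:

$\epsilon_t = \rho \, \frac{\nabla L(\theta_t)}{\lVert \nabla L(\theta_t) \rVert}, \qquad \theta_{t+1} = \theta_t - \eta \, \hat{D}_t^{-1} \, \nabla L(\theta_t + \epsilon_t)$

where $\rho$ is the perturbation radius (the --rho flag below) and $\hat{D}_t$ is a stabilized diagonal Hessian approximation that is refreshed only every few iterations (the --lazy_hessian flag).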
First, clone our repository to your local system:
git clone https://github.com/LOG-postech/Sassha.git
cd Sassha
We recommend using Anaconda to set up the environment and install all necessary dependencies:
conda create -n "sassha" python=3.9
conda activate sassha
pip install -r requirements.txt
Ensure you are using Python 3.9 or later.
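You can quickly verify the interpreter version and GPU visibility with a short check (a hypothetical snippet, not part of the repository):

import sys
import torch

# Confirm the Python version and that PyTorch can see a GPU.
assert sys.version_info >= (3, 9), "Python 3.9 or later is required"
print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")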
Navigate to the example folder of your choice. For instance, to run an image classification experiment:
cd image_classification
Now, train the model with the following command:
python train.py --workers 4 --dataset imagenet -a resnet50 --epochs 90 -b 256 \
--LRScheduler multi_step --lr-decay-epoch 30 60 --lr-decay 0.1 \
--optimizer sassha \
--lr 0.3 --wd 1e-4 --rho 0.2 --lazy_hessian 10 --seed 0 \
--project_name sassha \
{enter/your/imagenet-folder/with/train_and_val_data}
Here, replace {enter/your/imagenet-folder/with/train_and_val_data} with the path to your ImageNet dataset directory.
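The placeholder name suggests the standard torchvision ImageFolder layout, with class subdirectories under train and val (our assumption; adapt this to the repository's data-loading code if yours differs):

imagenet/
├── train/
│   ├── n01440764/
│   │   └── *.JPEG
│   └── ...
└── val/
    ├── n01440764/
    │   └── *.JPEG
    └── ...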
SASSHA is fully compatible with multi-GPU environments for distributed training. Use the following command to train a model across multiple GPUs on a single node:
python train.py --dist-url 'tcp://127.0.0.1:23456' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 \
--workers 4 --dataset imagenet -a vit_b_32 --epochs 90 -b 1024 \
--LRScheduler cosine --warmup_epochs 8 \
--optimizer sassha \
--lr 0.6 --wd 2e-4 --rho 0.25 --lazy_hessian 10 --eps 1e-6 --seed 0 \
--project_name sassha \
{enter/your/imagenet-folder/with/train_and_val_data}
Ensure that NCCL is properly configured on your system and that your GPUs are available before running the script.
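A quick way to verify both (a hypothetical one-liner, not part of the repository):

python -c "import torch; print('GPUs:', torch.cuda.device_count(), 'NCCL:', torch.distributed.is_nccl_available())"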
Configurations used in our paper are provided as shell scripts in each example folder.
Our experiments were conducted with the following environment:
- CUDA 11.6.2
- Python 3.9
SASSHA can be imported and used as follows:
from optimizers import SASSHA

...

# Initialize your model and optimizer
model = YourModel()
optimizer = SASSHA(model.parameters(), ...)

...

# training loop
for input, output in data:

    # first forward-backward pass
    loss = loss_function(output, model(input))
    loss.backward()
    optimizer.perturb_weights(zero_grad=True)

    # second forward-backward pass
    loss_function(output, model(input)).backward(create_graph=True)
    optimizer.unperturb()
    optimizer.step()
    optimizer.zero_grad()

...
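Note that the second backward pass uses create_graph=True so the optimizer can differentiate through the gradients when extracting Hessian information. Below is a minimal self-contained sketch of the loop above on toy data; the constructor keyword names mirror the CLI flags (lr, weight_decay, rho, lazy_hessian) and are our assumption, so check the optimizers package for the exact signature:

import torch
import torch.nn as nn
from optimizers import SASSHA

# Toy model and batch, for illustration only.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
loss_function = nn.CrossEntropyLoss()
# Keyword names are assumed to mirror the CLI flags.
optimizer = SASSHA(model.parameters(), lr=0.3, weight_decay=1e-4, rho=0.2, lazy_hessian=10)

inputs, targets = torch.randn(8, 10), torch.randint(0, 2, (8,))

# First forward-backward pass: gradient at the current weights, then perturb.
loss_function(model(inputs), targets).backward()
optimizer.perturb_weights(zero_grad=True)

# Second forward-backward pass at the perturbed point; create_graph=True
# retains the graph so second-order information can be computed.
loss_function(model(inputs), targets).backward(create_graph=True)
optimizer.unperturb()
optimizer.step()
optimizer.zero_grad()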
@article{shin2025sassha,
  title={SASSHA: Sharpness-aware Adaptive Second-order Optimization With Stable Hessian Approximation},
  author={Shin, Dahun and Lee, Dongyeop and Chung, Jinseok and Lee, Namhoon},
  journal={arXiv preprint arXiv:2502.18153},
  year={2025}
}