Gated Attention Coding for Training High-performance and Efficient Spiking Neural Networks (AAAI24)
Xuerui Qiu, Rui-Jie Zhu, Yuhong Chou, Zhaorui Wang, Liang-Jian Deng, Guoqi Li
Institute of Automation, Chinese Academy of Sciences
University of Electronic Science and Technology of China
University of California, Santa Cruz
Xi'an Jiaotong University
🚀 🚀 🚀 News:
- Dec. 19, 2023: Released the code for training and testing.
- Dec. 17, 2023: Accepted as a poster at AAAI 2024.
Spiking neural networks (SNNs) are emerging as an energy-efficient alternative to traditional artificial neural networks (ANNs) due to their unique spike-based, event-driven nature. Coding is crucial in SNNs, as it converts external input stimuli into spatio-temporal feature sequences. However, most existing deep SNNs rely on direct coding, which generates weak spike representations and lacks the temporal dynamics inherent in human vision. Hence, we introduce Gated Attention Coding (GAC), a plug-and-play module that leverages a multi-dimensional gated attention unit to efficiently encode inputs into powerful representations before feeding them into the SNN architecture. GAC functions as a preprocessing layer that does not disrupt the spike-driven nature of the SNN, making it amenable to efficient neuromorphic hardware implementation with minimal modifications. Through a theoretical analysis based on an observer model, we demonstrate that GAC's attention mechanism improves temporal dynamics and coding efficiency. Experiments on the CIFAR10/100 and ImageNet datasets show that GAC achieves state-of-the-art accuracy with remarkable efficiency. Notably, we improve top-1 accuracy by 3.10% on CIFAR100 with only 6 time steps and by 1.07% on ImageNet, while reducing energy usage to 66.9% of that of previous works. To the best of our knowledge, this is the first work to explore an attention-based dynamic coding scheme in deep SNNs, with exceptional effectiveness and efficiency on large-scale datasets.
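To make the contrast with direct coding concrete, here is a toy NumPy sketch of the general idea: the static image is repeated over T time steps (direct coding), then pooled statistics along the temporal, channel and spatial dimensions produce sigmoid gates that rescale each step differently. This is a simplified illustration under stated assumptions, not the paper's exact GAC formulation or the repository's module: `gated_attention_coding` is a hypothetical name, and random matrices stand in for learned attention weights.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gated_attention_coding(x, T, seed=0):
    """Toy gated attention coding (illustrative only, not the repo's GAC module).

    x: static input of shape [B, C, H, W].
    Returns a real-valued encoding of shape [B, T, C, H, W]
    (batch first, time second) to be fed into the first spiking layer.
    """
    rng = np.random.default_rng(seed)
    B, C, H, W = x.shape
    # Direct coding baseline: the same frame is repeated at every time step.
    seq = np.repeat(x[:, None], T, axis=1)  # [B, T, C, H, W]
    # Hypothetical random weights stand in for learned attention parameters.
    w_t = rng.standard_normal((T, T)) / np.sqrt(T)
    w_c = rng.standard_normal((C, C)) / np.sqrt(C)
    # Temporal gate: pool over C, H, W, then mix information across time steps.
    a_t = sigmoid(seq.mean(axis=(2, 3, 4)) @ w_t)  # [B, T]
    # Channel gate: pool over T, H, W, then mix across channels.
    a_c = sigmoid(seq.mean(axis=(1, 3, 4)) @ w_c)  # [B, C]
    # Spatial gate: pool over T and C.
    a_s = sigmoid(seq.mean(axis=(1, 2)))  # [B, H, W]
    # Multiply the three gates (broadcast to [B, T, C, H, W]) and apply them.
    gate = (a_t[:, :, None, None, None]
            * a_c[:, None, :, None, None]
            * a_s[:, None, None, :, :])
    return seq * gate
```

Because every gate lies in (0, 1) and the temporal gate differs per step, the encoded frames are no longer identical copies. The output stays real-valued, so in this picture the first spiking layer after the encoder is where spikes are produced, which is consistent with GAC acting as a preprocessing layer in front of the spike-driven network.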
The following setup has been tested and works:
- Python 3.7
- PyTorch 1.8.0
- CUDA 10.2
- We use a triangle-like surrogate gradient, `ZIF` in `models/layer.py`, for the forward and backward passes of the step function.
- Dimensions 0 and 1 of each SNN layer's input and output are the batch dimension and the time dimension, respectively.
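The two implementation notes above can be sketched together in NumPy: a Heaviside forward pass, a triangle-shaped surrogate for the backward pass, and a spiking neuron that iterates over dimension 1 (time) of a `[B, T, ...]` tensor. The exact triangle (height `1/gamma` at 0, zero beyond `|u| = gamma`) is an assumed common form of a triangle-like surrogate, and the leaky integrate-and-fire parameters (`tau`, `v_th`, hard reset) are hypothetical; in the repository, `ZIF` would be a PyTorch autograd function rather than two separate NumPy functions.

```python
import numpy as np


def zif_forward(u):
    # Heaviside step: spike (1.0) wherever the threshold-shifted potential u is positive.
    return (u > 0).astype(u.dtype)


def zif_backward(u, grad_out, gamma=1.0):
    # Triangle surrogate: height 1/gamma at u = 0, falls linearly to 0 at |u| = gamma,
    # so the surrogate integrates to 1 over u.
    surrogate = (1.0 / gamma) ** 2 * np.clip(gamma - np.abs(u), 0.0, None)
    return grad_out * surrogate


def lif_over_time(x, tau=2.0, v_th=1.0):
    """Leaky integrate-and-fire layer with hard reset (hypothetical parameters).

    x: input current of shape [B, T, ...]; dim 0 is batch, dim 1 is time.
    Returns spikes of the same shape.
    """
    v = np.zeros_like(x[:, 0])
    spikes = np.zeros_like(x)
    for t in range(x.shape[1]):  # iterate over the time dimension
        v = v + (x[:, t] - v) / tau  # leaky integration
        s = zif_forward(v - v_th)  # fire when v exceeds the threshold
        v = v * (1.0 - s)  # hard reset to 0 after a spike
        spikes[:, t] = s
    return spikes
```

For a constant input of 1.5 with these parameters, the membrane potential alternates between 0.75 and 1.125, so the neuron fires on every second step. Keeping batch first and time second lets `x[:, t]` select one time step for the whole batch.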
The most straightforward way to train higher-quality models is to increase their size. In this work, we aim to show that deepening the network structure can overcome the degradation problem and remain a reliable way to achieve satisfying accuracy in the direct training of SNNs.
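One way to see why a residual design can avoid degradation is the identity argument: if a block's residual branch outputs zero, the block passes its input through unchanged, so adding blocks cannot shrink the set of functions the network can represent. The NumPy sketch below assumes an MS-ResNet-style membrane shortcut (spiking neurons inside the residual branch, real-valued input added back by the shortcut) and simplifies it to a single time step. `ms_block` and `deep_net` are hypothetical names, and the `[C, C]` matrices stand in for conv + BN.

```python
import numpy as np


def spike(v, v_th=1.0):
    # Heaviside firing: 1.0 where the membrane-like input exceeds the threshold.
    return (v > v_th).astype(v.dtype)


def ms_block(x, w1, w2, v_th=1.0):
    """Simplified membrane-shortcut residual block, single time step.

    x: real-valued input [N, C]; w1, w2: [C, C] matrices that stand in for
    conv + BN (hypothetical). Spikes are generated inside the residual
    branch, and the shortcut adds the real-valued input x back unchanged.
    """
    h = spike(x, v_th) @ w1
    return x + spike(h, v_th) @ w2


def deep_net(x, blocks):
    # Stack blocks; each entry of `blocks` is a (w1, w2) pair.
    for w1, w2 in blocks:
        x = ms_block(x, w1, w2)
    return x
```

With the last weight matrix of every block set to zero, even a 50-block stack returns its input exactly, which is the property that makes deeper stacks a safe extension of shallower ones in this simplified picture.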
This repository contains the source code for training our MS-ResNet on ImageNet. The models are defined in `models/MS_ResNet.py`.
- Change the data paths `vardir`, `traindir` to the image folders of the ImageNet/CIFAR dataset.
- For the CIFAR datasets, run `run.sh` to train the model.
- For ImageNet, run `run.sh` or
  `CUDA_VISIBLE_DEVICES=GPU_IDs python -m torch.distributed.launch --master_port=1234 --nproc_per_node=NUM_GPU_USED train_amp.py -net resnet34 -b 256 -lr 0.1`
  to train the model. The `-net` option supports `resnet18` and `resnet34`.
Model | T | Params (M) | Top-1 Acc (%) on CIFAR10/CIFAR100 |
---|---|---|---|
GAC-MSResNet-18 | 6 | 12.63 | 96.46/80.45 |
GAC-MSResNet-18 | 4 | 12.63 | 96.24/79.83 |
GAC-MSResNet-18 | 2 | 12.63 | 96.18/78.92 |
Model | T | Params (M) | Power (mJ) | Top-1 Acc (%) on ImageNet |
---|---|---|---|---|
GAC-MSResNet-18 | 4 | 11.82 | 1.49 | 64.05 |
GAC-MSResNet-18 | 6 | 11.82 | 2.34 | 65.14 |
GAC-MSResNet-34 | 4 | 21.93 | 2.20 | 69.77 |
GAC-MSResNet-34 | 6 | 21.93 | 3.38 | 70.42 |
- Due to the upload file-size limit, we currently only release the weights for T=6 on CIFAR100 (`CIFAR100_T=6.pth`).
- To run the test code:
  - Change the data paths `vardir`, `traindir` to the image folders of the CIFAR100 dataset.
  - Then run `python test.py` in the `CIFAR` folder.
- Other weight files will be released soon.
Our T=6 weights on CIFAR100 can be downloaded from Baidu Netdisk:
link: https://pan.baidu.com/s/1jue9S9hAKFeYCs2iGWXUxg
extraction code: 4567
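Training on ImageNet goes through `torch.distributed.launch`, and checkpoints saved from a model wrapped in `DataParallel`/`DistributedDataParallel` commonly have every key prefixed with `module.`, which makes `load_state_dict` on an unwrapped model fail. Whether `CIFAR100_T=6.pth` stores such keys, or wraps the state dict in another dictionary, is an assumption to check; `strip_module_prefix` below is a hypothetical helper, not part of the repository.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel/DDP prefix removed.

    Hypothetical helper. Slicing is used instead of str.removeprefix,
    which only exists in Python 3.9+ (the tested setup is Python 3.7).
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned
```

A possible usage, assuming the file holds a plain state dict: `model.load_state_dict(strip_module_prefix(torch.load("CIFAR100_T=6.pth", map_location="cpu")))`. Loading with `map_location="cpu"` avoids requiring the same GPU layout that was used during training.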
- Use PyTorch to load the CIFAR10 and CIFAR100 datasets. Directory tree of `./data/`:
.
├── cifar-100-python
├── cifar-10-batches-py
- ImageNet, with the following folder structure (you can extract ImageNet with this script):
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
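Before launching a long ImageNet run, it can help to verify that the extracted data matches the tree above: `train/` and `val/` should each contain class folders named by WordNet synset IDs (such as `n01440764`), with the same set of classes in both. The standard-library sketch below only inspects folder names, not images; `check_imagenet_layout` is a hypothetical helper, not part of the repository.

```python
import os
import re

# ImageNet class folders are WordNet synset IDs such as "n01440764".
SYNSET_RE = re.compile(r"^n\d{8}$")


def check_imagenet_layout(root):
    """Return a list of problems with an ImageFolder-style ImageNet tree.

    Hypothetical helper: it only inspects folder names under root/train and
    root/val, not the images themselves. An empty list means no problems found.
    """
    problems = []
    classes = {}
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            problems.append(f"missing folder: {split_dir}")
            classes[split] = set()
            continue
        names = {
            d for d in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, d))
        }
        bad = sorted(n for n in names if not SYNSET_RE.match(n))
        if bad:
            problems.append(f"{split}: unexpected class folders {bad[:5]}")
        classes[split] = names
    if classes["train"] and classes["val"] and classes["train"] != classes["val"]:
        diff = sorted(classes["train"] ^ classes["val"])
        problems.append(f"train/ and val/ differ in class folders, e.g. {diff[:5]}")
    return problems
```

For the full ILSVRC-2012 dataset, each split should also contain 1000 class folders. Note that the official validation images ship in a single flat folder, so `val/` only has this layout after they have been sorted into per-class subfolders.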
@inproceedings{qiu2024gated,
title={Gated attention coding for training high-performance and efficient spiking neural networks},
author={Qiu, Xuerui and Zhu, Rui-Jie and Chou, Yuhong and Wang, Zhaorui and Deng, Liang-jian and Li, Guoqi},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={38},
number={1},
pages={601--610},
year={2024}
}
For help with or issues using this repository, please submit a GitHub issue.
For other communications related to this repository, please contact qiuxuerui2024@ia.ac.cn.