2021/01/08: This new version supports pruning with multi-GPU training. The code for pruning the standard torchvision ResNet-50 has been released, and the old version has been moved into the "deprecated" directory.
This repository contains the code for the following CVPR 2019 paper:
Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure.
This demo shows how to prune ResNet-50 on ImageNet with multiple GPUs (Distributed Data Parallel) and ResNet-56 on CIFAR-10.
The results reproduced on the torchvision version of ResNet-50 (FLOPs = 4.09B, top-1 accuracy = 76.15%) are as follows:
Final width | FLOPs reduction | Top-1 accuracy (%) | Download |
---|---|---|---|
Original torchvision model | - | 76.15 | - |
Internal layers 70% | 36% | 75.94 | https://drive.google.com/file/d/1kFyc8xH2bRAi-e3v1iC529hTLBIVASGa/view?usp=sharing |
Internal layers 60% | 46% | 75.80 | https://drive.google.com/file/d/1_2tWF-St06KVj49c8yLrAlWUv8fv-LLk/view?usp=sharing |
Internal layers 50% | 56% | 75.29 | https://drive.google.com/file/d/1BndZeq3QkMOAE3wLfltt5SzCJwVF9PLV/view?usp=sharing |
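As a rough sanity check on the FLOPs column, the sketch below counts the multiply-accumulates (MACs) of the conv and fully connected layers of a torchvision-style ResNet-50 from the flattened "deps" layout described below, ignoring BN, ReLU, pooling and residual additions. It assumes that the FLOPs figures in this README count multiply-accumulates, that downsampling happens in the 3x3 conv and the projection of stages 2 to 4 (as in torchvision), and that "internal layers" are the first two layers of each bottleneck. All helper names are hypothetical, not functions from this repository.

```python
BLOCKS_PER_STAGE = [3, 4, 6, 3]    # bottleneck blocks per stage in ResNet-50
BASE_WIDTHS = [64, 128, 256, 512]  # internal width of each stage; expansion factor is 4


def origin_deps():
    """Flattened widths: conv1, then per stage the projection followed by (a, b, c) per block."""
    deps = [64]
    for n_blocks, w in zip(BLOCKS_PER_STAGE, BASE_WIDTHS):
        deps.append(4 * w)                # projection (1x1 shortcut) is listed first
        deps += [w, w, 4 * w] * n_blocks  # the three layers of each bottleneck
    return deps


def internal_indices():
    """Indices of the first two layers of every bottleneck (assumed to be the 'internal' layers)."""
    idx, internal = 1, []
    for n_blocks in BLOCKS_PER_STAGE:
        idx += 1                          # skip the projection layer
        for _ in range(n_blocks):
            internal += [idx, idx + 1]
            idx += 3
    return internal


def resnet50_macs(deps, image_size=224, num_classes=1000):
    """Count conv + fc multiply-accumulates for one image."""
    it = iter(deps)
    c1 = next(it)
    hw = image_size // 2                  # 7x7 conv1 with stride 2
    macs = 7 * 7 * 3 * c1 * hw * hw
    hw //= 2                              # 3x3 max pooling with stride 2
    in_ch = c1
    for stage, n_blocks in enumerate(BLOCKS_PER_STAGE):
        stride = 1 if stage == 0 else 2
        proj = next(it)
        macs += in_ch * proj * (hw // stride) ** 2  # strided 1x1 projection
        for b in range(n_blocks):
            a, mid, c = next(it), next(it), next(it)
            s = stride if b == 0 else 1
            macs += in_ch * a * hw * hw             # 1x1 reduce at the input resolution
            hw //= s
            macs += 9 * a * mid * hw * hw           # 3x3 conv carries the stride
            macs += mid * c * hw * hw               # 1x1 expand
            in_ch = c
    return macs + in_ch * num_classes     # final fully connected layer


def prune_internal(deps, ratio):
    """Shrink only the internal layers to `ratio` of their width, truncating like int()."""
    pruned = list(deps)
    for i in internal_indices():
        pruned[i] = int(ratio * pruned[i])
    return pruned


if __name__ == "__main__":
    base = resnet50_macs(origin_deps())
    print(f"original: {base / 1e9:.2f}G MACs")
    for ratio in (0.7, 0.6, 0.5):
        pruned = resnet50_macs(prune_internal(origin_deps(), ratio))
        print(f"internal layers {ratio:.0%}: {1 - pruned / base:.1%} reduction")
```

Under these assumptions the original model comes out at about 4.09G MACs, and the three pruned widths give reductions within about a percentage point of the 36%, 46% and 56% in the table; small differences can come from how widths are rounded and which operations are counted.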
Citation:
@inproceedings{ding2019centripetal,
title={Centripetal sgd for pruning very deep convolutional networks with complicated structure},
author={Ding, Xiaohan and Ding, Guiguang and Guo, Yuchen and Han, Jungong},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={4943--4953},
year={2019}
}
Filter pruning, a.k.a. network slimming or channel pruning, aims to remove some filters from a CNN so as to slim it with an acceptable performance drop. We seek to make some filters increasingly close and eventually identical for network slimming. To this end, we propose Centripetal SGD (C-SGD), a novel optimization method that can train several filters to collapse into a single point in the parameter hyperspace. When training is completed, removing the identical filters trims the network with NO performance loss, so no finetuning is needed. By doing so, we have partly solved an open problem of constrained filter pruning on CNNs with complicated structure, where some layers must be pruned following others.
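To make the "collapse into a single point" idea concrete, here is a toy NumPy sketch, not the implementation in this repository: the filters of one cluster all receive the cluster-averaged gradient plus a centripetal pull (strength `centripetal`, a hypothetical name) toward the cluster mean, so their pairwise differences shrink by a constant factor every step. Once they are identical, keeping one filter and summing the matching input channels of the next layer leaves the output unchanged. A 1x1 conv is modelled as a matrix multiply, the objective is a made-up quadratic, and biases and BN are omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
in_ch, out_ch, next_ch = 8, 4, 3
W = rng.normal(size=(out_ch, in_ch))         # layer to prune: 4 filters (a 1x1 conv as a matrix)
W_next = rng.normal(size=(next_ch, out_ch))  # following layer, consuming those 4 channels
targets = rng.normal(size=(out_ch, in_ch))   # toy objective: 0.5 * ||W - targets||^2
cluster = [0, 1, 2]                          # filters 0-2 should collapse; filter 3 is left alone

lr, weight_decay, centripetal = 0.1, 1e-4, 0.05
for _ in range(500):
    grad = W - targets                                     # gradient of the toy objective
    grad[cluster] = grad[cluster].mean(axis=0)             # cluster members share the averaged gradient
    step = lr * (grad + weight_decay * W)
    center = W[cluster].mean(axis=0)
    step[cluster] -= centripetal * (center - W[cluster])   # centripetal pull toward the cluster center
    W -= step

spread = np.abs(W[cluster] - W[cluster].mean(axis=0)).max()
print(f"max deviation inside the cluster: {spread:.1e}")

# Prune: keep filter 0 (for the cluster) and filter 3, and fold the next layer's input
# channels 0-2 into one, since identical filters produce identical channels.
keep = [0, 3]
W_pruned = W[keep]
W_next_pruned = W_next[:, keep]              # fancy indexing returns a copy
W_next_pruned[:, 0] = W_next[:, cluster].sum(axis=1)


def relu(z):
    return np.maximum(z, 0)


x = rng.normal(size=(in_ch, 5))              # 5 toy input positions
y_full = W_next @ relu(W @ x)
y_pruned = W_next_pruned @ relu(W_pruned @ x)
print("outputs match:", np.allclose(y_full, y_pruned))
```

Because the cluster members share one gradient, only the centripetal and weight-decay terms act on their differences, which therefore decay geometrically by a factor of (1 - lr * weight_decay - centripetal) per step.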
To prune ResNet-50 on ImageNet with multiple GPUs:
- Enter the root directory of this repository.
- Make a soft link to your ImageNet directory, which contains the "train" and "val" directories.
ln -s YOUR_PATH_TO_IMAGENET imagenet_data
- Set the environment variables.
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
- Download the official torchvision model, rename its parameters to follow our naming convention, and save the weights to "torchvision_res50.hdf5".
python transform_torchvision.py
- Run Centripetal SGD to iteratively prune the internal layers of ResNet-50 to 70% of the original width, then to 60%, 50%, 40% and 30%.
python -m torch.distributed.launch --nproc_per_node=8 csgd/do_csgd.py -a sres50 -i 0
python -m torch.distributed.launch --nproc_per_node=8 csgd/do_csgd.py -a sres50 -i 1
python -m torch.distributed.launch --nproc_per_node=8 csgd/do_csgd.py -a sres50 -i 2
python -m torch.distributed.launch --nproc_per_node=8 csgd/do_csgd.py -a sres50 -i 3
python -m torch.distributed.launch --nproc_per_node=8 csgd/do_csgd.py -a sres50 -i 4
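Each command above starts a separate distributed job, and since every iteration prunes a progressively narrower target, it seems safest to start the next one only after the previous one has finished. A small, hypothetical Python driver (not part of this repository) can run them in order and stop at the first failure; it only reuses the command-line flags shown above.

```python
import os
import subprocess
import sys


def csgd_commands(arch, iterations, nproc_per_node=None):
    """Build one do_csgd.py command per iteration index (-i), using the flags shown above."""
    if nproc_per_node:   # multi-GPU: one process per GPU via torch.distributed.launch
        prefix = [sys.executable, "-m", "torch.distributed.launch",
                  f"--nproc_per_node={nproc_per_node}"]
    else:                # single GPU, as for ResNet-56 on CIFAR-10
        prefix = [sys.executable]
    return [prefix + ["csgd/do_csgd.py", "-a", arch, "-i", str(i)] for i in range(iterations)]


def run_in_order(commands):
    """Run the commands one after another, stopping at the first non-zero exit status."""
    for cmd in commands:
        print("running:", " ".join(cmd), flush=True)
        subprocess.run(cmd, check=True)   # raises CalledProcessError on failure


# Launch only from the repository root, with PYTHONPATH and CUDA_VISIBLE_DEVICES exported.
if __name__ == "__main__" and os.path.isfile("csgd/do_csgd.py"):
    run_in_order(csgd_commands("sres50", iterations=5, nproc_per_node=8))
```

Using sys.executable instead of a bare "python" makes the child processes use the same interpreter (and virtual environment) as the driver. For the CIFAR-10 runs, csgd_commands("src56", iterations=3) builds the three single-GPU commands instead.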
To prune ResNet-56 on CIFAR-10, we train a ResNet-56 (with 16-32-64 channels) and iteratively slim it into 13/16, 11/16 and 5/8 of the original width.
- Enter the root directory of this repository.
- Make a soft link to your CIFAR-10 directory. If the dataset is not found in the directory, it will be downloaded automatically.
ln -s YOUR_PATH_TO_CIFAR cifar10_data
- Set the environment variables.
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0
- Run Centripetal SGD to train a base ResNet-56, then globally slim it into 13/16, 11/16 and 5/8 of the original width.
python csgd/do_csgd.py -a src56 -i 0
python csgd/do_csgd.py -a src56 -i 1
python csgd/do_csgd.py -a src56 -i 2
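For reference, applying the three fractions to the 16-32-64 channels gives the per-stage widths below. This is plain arithmetic under the assumption that every layer of a stage is scaled by the same fraction; the actual targets are whatever "itr_deps" in do_csgd.py specifies, and scaled_widths is a hypothetical helper.

```python
from fractions import Fraction

BASE_WIDTHS = (16, 32, 64)   # channels of the three ResNet-56 stages


def scaled_widths(fraction):
    """Scale every stage width by `fraction`, insisting that the result is an integer."""
    widths = [w * fraction for w in BASE_WIDTHS]
    if any(x.denominator != 1 for x in widths):
        raise ValueError(f"{fraction} does not give integer widths")
    return [int(x) for x in widths]


for target in (Fraction(13, 16), Fraction(11, 16), Fraction(5, 8)):
    print(f"{target}: {scaled_widths(target)}")
```

This prints 13/16: [13, 26, 52], 11/16: [11, 22, 44] and 5/8: [10, 20, 40]. Since 16 is the smallest stage width, any multiple of 1/16 keeps all three stages integral.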
To test a model, download any of the models above and run, for example,
python ndp_test.py sres50 csgd_res50_internal70.hdf5
The pruned models can be used as usual for your own tasks, such as detection and segmentation.
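To check the structure of a downloaded model before reusing it, you can read the kernel shapes from the .hdf5 file. The sketch below assumes that each parameter is stored as a top-level dataset named after the parameter and that conv kernels use PyTorch's (out_channels, in_channels, kh, kw) layout; both are assumptions about the file format, and the function names are hypothetical. Reading the file requires the third-party h5py package.

```python
import math
import sys


def conv_widths(shapes):
    """Output channels of every 4-D kernel, assuming PyTorch's (out, in, kh, kw) layout."""
    return {name: shape[0] for name, shape in shapes.items() if len(shape) == 4}


def num_stored_values(shapes):
    """Total number of stored values (weights, biases and any BN statistics in the file)."""
    return sum(math.prod(shape) for shape in shapes.values())


if __name__ == "__main__" and len(sys.argv) > 1:
    import h5py  # third-party package for reading .hdf5 files

    with h5py.File(sys.argv[1], "r") as f:   # e.g. csgd_res50_internal70.hdf5
        shapes = {name: obj.shape for name, obj in f.items() if isinstance(obj, h5py.Dataset)}
    for name, width in conv_widths(shapes).items():
        print(f"{name}: {width}")
    print(f"total stored values: {num_stored_values(shapes)}")
```

Only the dataset shapes are read, so the script stays fast even for large checkpoints.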
For every ConvNet in this codebase, the width of each conv layer is defined by an array named "deps". For example, the original deps of ResNet-50 are
RESNET50_ORIGIN_DEPS_FLATTENED = [64,
    256, 64, 64, 256, 64, 64, 256, 64, 64, 256,
    512, 128, 128, 512, 128, 128, 512, 128, 128, 512, 128, 128, 512,
    1024, 256, 256, 1024, 256, 256, 1024, 256, 256, 1024, 256, 256, 1024, 256, 256, 1024, 256, 256, 1024,
    2048, 512, 512, 2048, 512, 512, 2048, 512, 512, 2048]
Note that we build the projection (1x1 conv shortcut) layer before the parallel residual block (L61 in stagewise_resnet.py), so its width (256) precedes the widths of the three layers of the residual block (64, 64, 256). In do_csgd.py, "itr_deps" defines the target structure of the pruned model for each iteration. For example, to customize the final width by pruning every internal layer by 42% and the other troublesome layers by 39%, do something like this:
import numpy as np

# RESNET50_ORIGIN_DEPS_FLATTENED and RESNET50_INTERNAL_KERNEL_IDXES come from this codebase
final_deps = np.array(RESNET50_ORIGIN_DEPS_FLATTENED)
for i in range(1, len(RESNET50_ORIGIN_DEPS_FLATTENED)):  # start from 0 if you want to prune the first layer
    if i in RESNET50_INTERNAL_KERNEL_IDXES:
        final_deps[i] = int(0.58 * final_deps[i])  # internal layers: keep 58%, i.e. prune 42%
    else:
        final_deps[i] = int(0.61 * final_deps[i])  # other (troublesome) layers: keep 61%, i.e. prune 39%
itr_deps = [final_deps]  # prune in one iteration; to use several iterations, define a series of deps, e.g. with "generate_itr_to_target_deps_by_schedule_vector"
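The comment in the snippet above suggests defining a series of deps to prune over several iterations. One simple way to build such a series, shown here as a hypothetical helper rather than the behaviour of "generate_itr_to_target_deps_by_schedule_vector", is to move linearly from the original widths to the final ones, so each iteration removes a similar number of channels from every layer. The example uses only the first five entries of the ResNet-50 deps for brevity.

```python
import numpy as np


def linear_itr_deps(origin_deps, final_deps, num_iterations):
    """Return num_iterations deps arrays stepping linearly from origin to final widths."""
    origin = np.asarray(origin_deps, dtype=np.int64)
    final = np.asarray(final_deps, dtype=np.int64)
    if origin.shape != final.shape or np.any(final > origin) or np.any(final < 1):
        raise ValueError("final_deps must have the same length and be positive and no wider")
    removed = origin - final
    # Integer arithmetic avoids float rounding; the last step lands exactly on final_deps.
    return [origin - (removed * k) // num_iterations for k in range(1, num_iterations + 1)]


if __name__ == "__main__":
    origin = [64, 256, 64, 64, 256]          # conv1, projection, first bottleneck
    final = [64, 156, 37, 37, 156]           # int(0.61 * 256) = 156, int(0.58 * 64) = 37
    for deps in linear_itr_deps(origin, final, num_iterations=3):
        print(deps.tolist())
```

Layers that must keep equal widths (such as the projection and the last layer of each block) share the same original and final values, so a linear schedule keeps them equal at every intermediate iteration as well.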
Google Scholar Profile: https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en
My open-sourced papers and repos:
State-of-the-art channel pruning (preprint, 2020): Lossless CNN Channel Pruning via Gradient Resetting and Convolutional Re-parameterization (https://github.com/DingXiaoH/ResRep)
CNN component (ICCV 2019): ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks (https://github.com/DingXiaoH/ACNet)
Channel pruning (CVPR 2019): Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure (https://github.com/DingXiaoH/Centripetal-SGD)
Channel pruning (ICML 2019): Approximated Oracle Filter Pruning for Destructive CNN Width Optimization (https://github.com/DingXiaoH/AOFP)
Unstructured pruning (NeurIPS 2019): Global Sparse Momentum SGD for Pruning Very Deep Neural Networks (https://github.com/DingXiaoH/GSM-SGD)