
Commit

dingxiaohan committed Jan 8, 2021
2 parents: 8c9c358 + f9d83b4 · commit 4d4a568
Showing 2 changed files with 13 additions and 9 deletions.
README.md: 10 changes (6 additions, 4 deletions)
@@ -1,8 +1,8 @@
# Centripetal-SGD

-Update: Pytorch implemenation released.
+2020/12/31: Will be updated in several days with a multi-GPU PyTorch implementation (Distributed Data Parallel) and pruning scripts for the standard torchvision ResNet-50 (76.15% accuracy). The results are pretty good.

-Note: The critical codes for C-SGD training (csgd/csgd_train.py) and pruning (csgd/csgd_prune.py) have been refactored and cleaned, such that the readability has significantly improved. The Tensorflow codes also work, but I would not suggest you read them. A little trick: using smaller centripetal strength on the scaling factor of BN improves the performance in some of the cases.
+Update: PyTorch implementation released. Fixed a bug in the PyTorch implementation (csgd/csgd_prune.py) related to pruning the last conv layer when it is followed by an FC layer. The bug caused an error only if the last-layer feature maps were flattened as input to the FC layer; for models with Global Average Pooling, such as ResNet-56, it was harmless. The critical code for C-SGD training (csgd/csgd_train.py) and pruning (csgd/csgd_prune.py) has been refactored and cleaned, so its readability is significantly improved. The TensorFlow code also works, but I would not suggest reading it. A little trick: using a smaller centripetal strength on the scaling factor of BN improves performance in some cases (see the sketch after this diff).

This repository contains the codes for the following CVPR-2019 paper

@@ -76,12 +76,14 @@ dxh17@mails.tsinghua.edu.cn

Google Scholar Profile: https://scholar.google.com/citations?user=CIjw0KoAAAAJ&hl=en

-My open-sourced papers and repos:
+My open-sourced papers and repos:

**State-of-the-art** channel pruning (preprint, 2020): [Lossless CNN Channel Pruning via Gradient Resetting and Convolutional Re-parameterization](https://arxiv.org/abs/2007.03260) (https://github.com/DingXiaoH/ResRep)

CNN component (ICCV 2019): [ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks](http://openaccess.thecvf.com/content_ICCV_2019/papers/Ding_ACNet_Strengthening_the_Kernel_Skeletons_for_Powerful_CNN_via_Asymmetric_ICCV_2019_paper.pdf) (https://github.com/DingXiaoH/ACNet)

Channel pruning (CVPR 2019): [Centripetal SGD for Pruning Very Deep Convolutional Networks with Complicated Structure](http://openaccess.thecvf.com/content_CVPR_2019/html/Ding_Centripetal_SGD_for_Pruning_Very_Deep_Convolutional_Networks_With_Complicated_CVPR_2019_paper.html) (https://github.com/DingXiaoH/Centripetal-SGD)

Channel pruning (ICML 2019): [Approximated Oracle Filter Pruning for Destructive CNN Width Optimization](http://proceedings.mlr.press/v97/ding19a.html) (https://github.com/DingXiaoH/AOFP)

-Unstructured pruning (NeurIPS 2019): [Global Sparse Momentum SGD for Pruning Very Deep Neural Networks](https://arxiv.org/pdf/1909.12778.pdf) (https://github.com/DingXiaoH/GSM-SGD)
+Unstructured pruning (NeurIPS 2019): [Global Sparse Momentum SGD for Pruning Very Deep Neural Networks](http://papers.nips.cc/paper/8867-global-sparse-momentum-sgd-for-pruning-very-deep-neural-networks.pdf) (https://github.com/DingXiaoH/GSM-SGD)
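The "smaller centripetal strength on the scaling factor of BN" trick mentioned in the README update above can be made concrete with a minimal sketch of the C-SGD update rule (my illustration, not code from this repo): within each cluster, filters share the cluster-averaged gradient and are pulled toward the cluster mean, and the BN scaling vector can simply be given a smaller pulling strength. The function name `csgd_step` and the example strength values are hypothetical; momentum and weight decay are omitted.

```python
import torch

def csgd_step(weights, grads, clusters, lr, strength):
    # Hedged sketch of one C-SGD step for a single layer. `weights` and
    # `grads` are (num_filters, ...) tensors; `clusters` is a list of
    # filter-index lists. Filters in the same cluster receive the
    # cluster-averaged gradient plus a centripetal term that pulls them
    # toward the cluster mean with the given strength.
    with torch.no_grad():
        for clst in clusters:
            idx = torch.tensor(clst, dtype=torch.long)
            g_mean = grads[idx].mean(dim=0)    # averaged gradient of the cluster
            w_mean = weights[idx].mean(dim=0)  # cluster centroid
            weights[idx] -= lr * g_mean + strength * (weights[idx] - w_mean)

# The trick: a smaller strength for BN gamma than for the conv kernel
# (hypothetical values, to be tuned per model):
# csgd_step(conv_kernel, conv_grad, clusters, lr=0.1, strength=3e-3)
# csgd_step(bn_gamma, gamma_grad, clusters, lr=0.1, strength=3e-4)
```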
csgd/csgd_prune.py: 12 changes (7 additions, 5 deletions)
@@ -79,18 +79,20 @@ def handle_vecs(key_name):
        num_filters = k_value.shape[0]
        fc_neurons_per_conv_kernel = follow_kernel_value.shape[1] // num_filters
        print('{} filters, {} neurons per kernel'.format(num_filters, fc_neurons_per_conv_kernel))
-       base = np.arange(0, fc_neurons_per_conv_kernel * num_filters, num_filters)
+
        for clst in clusters:
            if len(clst) == 1:
                continue
            for i in clst[1:]:
-               fc_idx_to_delete.append(base + i)
+               fc_idx_to_delete.append(np.arange(i * fc_neurons_per_conv_kernel,
+                                                 (i+1) * fc_neurons_per_conv_kernel))
            to_concat = []
            for i in clst:
-               corresponding_neurons_idx = base + i
+               corresponding_neurons_idx = np.arange(i * fc_neurons_per_conv_kernel,
+                                                     (i+1) * fc_neurons_per_conv_kernel)
                to_concat.append(np.expand_dims(follow_kernel_value[:, corresponding_neurons_idx], axis=0))
            summed = np.sum(np.concatenate(to_concat, axis=0), axis=0)
-           reserved_idx = base + clst[0]
+           reserved_idx = np.arange(clst[0] * fc_neurons_per_conv_kernel, (clst[0]+1) * fc_neurons_per_conv_kernel)
            follow_kernel_value[:, reserved_idx] = summed
        if len(fc_idx_to_delete) > 0:
            follow_kernel_value = delete_or_keep(follow_kernel_value, np.concatenate(fc_idx_to_delete, axis=0), axis=1)
@@ -115,4 +117,4 @@ def handle_vecs(key_name):
    result['deps'] = new_deps

    print('save {} values to {} after pruning'.format(len(result), save_file))
-   save_hdf5(result, save_file)
+   save_hdf5(result, save_file)

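Why the indexing change above fixes the bug: when a PyTorch (N, C, H, W) feature map is flattened into an FC layer, the neurons fed by conv filter c occupy the contiguous block [c*H*W, (c+1)*H*W), which is exactly what the new np.arange expressions select. The deleted `base + i` expression instead picked a channel-strided pattern, which would be correct for a channels-last (NHWC) flatten as in TensorFlow, presumably a leftover of the TensorFlow version. A small self-contained sketch (my illustration of the two layouts; none of this code is from the repo):

```python
import numpy as np

num_filters, H, W = 4, 2, 2          # tiny example
neurons_per_kernel = H * W           # cf. fc_neurons_per_conv_kernel above
fmap = np.arange(num_filters * H * W).reshape(num_filters, H, W)  # one sample, CHW
c = 1                                # the conv filter we care about

# PyTorch: x.view(N, -1) on an NCHW tensor flattens channel-major, so filter c
# owns a contiguous block of FC input neurons -- the new indexing.
flat_nchw = fmap.reshape(-1)
new_idx = np.arange(c * neurons_per_kernel, (c + 1) * neurons_per_kernel)
assert np.array_equal(flat_nchw[new_idx], fmap[c].reshape(-1))

# A channels-last (NHWC) flatten interleaves channels, so filter c's neurons
# are strided -- the pattern the deleted `base + i` expression selected.
flat_nhwc = fmap.transpose(1, 2, 0).reshape(-1)
base = np.arange(0, neurons_per_kernel * num_filters, num_filters)
assert np.array_equal(flat_nhwc[base + c], fmap[c].reshape(-1))

# With Global Average Pooling there is exactly one neuron per filter
# (neurons_per_kernel == 1), both schemes reduce to index c, and the old
# code was harmless -- as the README notes for ResNet-56.
```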