Unofficial PyTorch implementation of SNIP (ICLR 2019). SNIP is a single-shot neural network pruning technique: it prunes the network once, before training, by ranking connections according to their sensitivity at the randomly initialized weights.
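
To make the idea of connection sensitivity concrete, here is a minimal sketch of how such masks could be computed in PyTorch. It is not this repository's code: the function name `snip_masks`, the `keep_ratio` argument, and the choice to prune only tensors with more than one dimension are assumptions made for illustration. The score used is `|gradient * weight|` from a single batch at initialization, which matches the sensitivity criterion described in the SNIP paper up to a normalization constant that does not change the ranking.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def snip_masks(model, loss_fn, inputs, targets, keep_ratio):
    """Hypothetical helper: one binary mask per weight tensor.

    Keeps the `keep_ratio` fraction of connections with the highest
    sensitivity, ranked globally across all layers.
    """
    # Assumption: only weight matrices/kernels are pruned; biases are kept.
    weights = [p for p in model.parameters() if p.dim() > 1]

    # One forward/backward pass on a single batch, before any training.
    loss = loss_fn(model(inputs), targets)
    grads = torch.autograd.grad(loss, weights)

    # Sensitivity of each connection: |dL/dw * w|.
    scores = [(g * w).abs() for g, w in zip(grads, weights)]
    all_scores = torch.cat([s.flatten() for s in scores])

    # Global threshold: the num_keep-th largest score.
    num_keep = max(1, int(keep_ratio * all_scores.numel()))
    threshold = torch.topk(all_scores, num_keep).values[-1]

    # Connections at or above the threshold survive (ties are all kept).
    return [(s >= threshold).float() for s in scores]
```

Ranking globally rather than per layer lets the criterion decide how much of each layer to keep, at the cost that a layer with uniformly small scores can be pruned very heavily.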
from snip_prunner import Prunner
from model import my_model
from loss_func import my_loss
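# train_dataloader is your training DataLoader (not defined in this snippet)
prunner = Prunner(my_model, my_loss, train_dataloader)
# num_batch_sampling: number of batches sampled from train_dataloader for pruning
prunned_model, masks = prunner.prun(compression_factor=0.9, num_batch_sampling=1)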
"""
Now continue training prunned_model
as you would in a normal setup.
"""

Refer to test_mnist.ipynb for experiments on MNIST.

| Parameters kept \ Batches sampled | 1 | 10 |
|---|---|---|
| 90% | 97.74 | 97.70 |
| 75% | 97.79 | 97.79 |
| 50% | 97.74 | 97.67 |
| 10% | 96.69 | 96.69 |
| 2% | 93.01 | 93.69 |
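
When continuing to train a pruned model, the pruned connections need to stay at zero after every optimizer step. The sketch below shows one common way to do that. It assumes `masks` is a list of 0/1 tensors aligned with a list of weight tensors; the README does not specify the format returned by `prunner.prun`, nor whether the returned model already enforces the masks itself, so treat `apply_masks` and `masked_train_step` as hypothetical helpers rather than part of this repository.

```python
import torch


def apply_masks(weights, masks):
    """Zero out pruned connections in place (hypothetical helper)."""
    with torch.no_grad():
        for w, m in zip(weights, masks):
            w.mul_(m)


def masked_train_step(model, loss_fn, optimizer, weights, masks, inputs, targets):
    """One ordinary training step that keeps pruned weights at zero."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # The optimizer may have moved pruned weights away from zero
    # (e.g. through momentum), so the masks are re-applied after each step.
    apply_masks(weights, masks)
    return loss.item()
```

Re-applying the masks after the step is simple and optimizer-agnostic; masking the gradients instead is an alternative, but it does not by itself stop momentum or weight decay from changing pruned weights.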
Run experiments using a ResNet model on CIFAR-10