-
Notifications
You must be signed in to change notification settings - Fork 3
/
prune.py
41 lines (36 loc) · 1.34 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from tqdm import tqdm
import torch
import numpy as np
import torch.nn as nn
# Names accepted by prune_loop's ``schedule`` argument.
_PRUNE_SCHEDULES = ('exponential', 'linear', 'expinv')


def _schedule_sparsity(schedule, sparsity, epoch, epochs):
    """Return the keep-fraction to mask to in round ``epoch`` (0-based) of ``epochs``.

    For every schedule the final round (``epoch == epochs - 1``) yields ``sparsity``.
    """
    if schedule == 'exponential':
        return sparsity**((epoch + 1) / epochs)
    if schedule == 'linear':
        return 1.0 - (1.0 - sparsity)*((epoch + 1) / epochs)
    if schedule == 'expinv':
        return 1.0 - (1.0 - sparsity)**(epochs/(epoch + 1))
    raise ValueError("unknown schedule {!r}; expected one of {}".format(schedule, _PRUNE_SCHEDULES))


def prune_loop(model, loss, pruner, dataloader, device, sparsity, schedule, scope, epochs,
               reinitialize=False, train_mode=False, shuffle=False, invert=False):
    r"""Iteratively score and mask a model until it reaches a final sparsity level.

    Each of the ``epochs`` rounds calls ``pruner.score`` and then ``pruner.mask``
    with an intermediate keep-fraction produced by ``schedule``. The last round
    always masks to ``sparsity`` itself.

    Args:
        model: Network to prune. It is put in train mode if ``train_mode`` is
            true, otherwise in eval mode. If ``reinitialize`` is true it must
            provide ``_initialize_weights()``.
        loss: Passed through to ``pruner.score``.
        pruner: Object providing ``score``, ``mask`` and ``stats``. It must also
            provide ``invert`` / ``shuffle`` when the matching flags are set.
        dataloader: Passed through to ``pruner.score``.
        device: Passed through to ``pruner.score``.
        sparsity: Target fraction of prunable parameters to keep after the final
            round. The final check compares ``remaining`` against
            ``total * sparsity``.
        schedule: How the per-round keep-fraction approaches ``sparsity``. One of
            ``'exponential'``, ``'linear'`` or ``'expinv'``.
        scope: Passed through to ``pruner.mask``.
        epochs: Number of score/mask rounds.
        reinitialize: Call ``model._initialize_weights()`` after pruning.
        train_mode: Keep the model in train mode instead of eval mode.
        shuffle: Call ``pruner.shuffle()`` after pruning.
        invert: Call ``pruner.invert()`` after each scoring pass, before masking.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names. This is
            checked before any scoring work is done.
        SystemExit: With status 1 if the remaining prunable-parameter count
            differs from ``total_params * sparsity`` by 5 or more.
    """
    # Fail fast: an unknown schedule used to surface as an UnboundLocalError
    # only after the first (potentially expensive) scoring pass.
    if schedule not in _PRUNE_SCHEDULES:
        raise ValueError("unknown schedule {!r}; expected one of {}".format(schedule, _PRUNE_SCHEDULES))

    # Set model to train or eval mode.
    if train_mode:
        model.train()
    else:
        model.eval()

    # Prune model: re-score each round, then mask to the scheduled level.
    for epoch in tqdm(range(epochs)):
        pruner.score(model, loss, dataloader, device)
        if invert:
            pruner.invert()
        pruner.mask(_schedule_sparsity(schedule, sparsity, epoch, epochs), scope)

    # Reinitialize weights (masks are left in place).
    if reinitialize:
        model._initialize_weights()

    # Shuffle masks.
    if shuffle:
        pruner.shuffle()

    # Confirm sparsity level (tolerance of 5 parameters).
    remaining_params, total_params = pruner.stats()
    expected_params = total_params*sparsity
    if np.abs(remaining_params - expected_params) >= 5:
        print("ERROR: {} prunable parameters remaining, expected {}".format(remaining_params, expected_params))
        # quit() is only defined when the `site` module is loaded and exits
        # with status 0; signal the failure with a non-zero status instead.
        raise SystemExit(1)