
Commit b3f7ff3

Fixed variance redistribution function.
1 parent 93db164 commit b3f7ff3

4 files changed: +37 -22 lines changed


CHANGELOG.md (+1 -1)

```diff
@@ -27,4 +27,4 @@
 ## Features:
 - Added FP16 support. Any model can now be run in 16-bit by passing the [apex](https://github.com/NVIDIA/apex) `FP16_Optimizer` into the `Masking` class and replacing `loss.backward()` with `optimizer.backward(loss)`.
 - Added adapted [Dynamic Sparse Reparameterization](https://arxiv.org/abs/1902.05967) [codebase](https://github.com/IntelAI/dynamic-reparameterization) that works with sparse momentum.
-- Added modular architecture for growth/prune/redistribution algorithms which is decoupled from the main library. This enables you to write your own prune/growth/redistribution algorithms without touched the library internals. A tutorial on how to add your own functions was also added: [How to Add Your Own Algorithms](How_to_add_your_own_algorithms.md]).
+- Added modular architecture for growth/prune/redistribution algorithms which is decoupled from the main library. This enables you to write your own prune/growth/redistribution algorithms without touched the library internals. A tutorial on how to add your own functions was also added: [How to Add Your Own Algorithms](How_to_add_your_own_algorithms.md).
```

How_to_add_your_own_algorithms.md (+20 -14)

````diff
@@ -57,24 +57,33 @@ masking.total_removed
 
 Here I added two example extensions for redistribution and pruning. These two examples look at the variance of the gradient. If we look at weights with high and low variance in their gradients over time, then we can have the following interpretations.
 
-For high variance weights, we can have two perspectives. The first one would assume that weights with high variance are unable to model the interactions in the inputs to classify the outputs due to a lack of capacity. For example a weight might have a problem to be useful for both the digit 0 and digit 7 when classifying MNIST and thus has high variance between these examples. If we add capacity to high variance layers, then we should reduce the variance over time (one weight for 7 one weight for 0). According to this perspective we want to add more parameters to layers with high average variance. In other words, we want to redistribute pruned parameters to layers with high gradient variance.
+For high variance weights, we can have two perspectives. The first one would assume that weights with high variance are unable to model the interactions in the inputs to classify the outputs due to a lack of capacity. For example a weight might have a problem to be useful for both the digit 0 and digit 7 when classifying MNIST and thus has high variance between these examples. If we add capacity to high variance layers, then we should reduce the variance over time meaning the new weights can now fully model the different classes (one weight for 7 one weight for 0). According to this perspective we want to add more parameters to layers with high average variance. In other words, we want to redistribute pruned parameters to layers with high gradient variance.
 
-The second perspective is a "potential of be useful" perspective. Here we see weights with high variance as having "potential to do the right classification, but they might just not have found the right decision boundary between classes yet". For example, a weight might have problems being useful for both the digit 7 and 0 but overtime it can find a feature which is useful for both classes. Thus gradient variance should reduce over time as features become more stable. If we take this perspective then it is important to keep some medium-to-high variance weights. Low variance weights have "settled in" and follow the gradient for a specific set of classes. These weights will not change much anymore while high variance weights might change a lot. So high variance weights might have "potential" while the potential of low variance weights is easily assessed by looking at the magnitude of that weights. Thus we might improve pruning if we look at both the variance of the gradient _and_ the magnitude of weights. You can find these examples in ['mnist_cifar/extensions.py']('mnist_cifar/extensions.py').
+The second perspective is a "potential of be useful" perspective. Here we see weights with high variance as having "potential to do the right classification, but they might just not have found the right decision boundary between classes yet". For example, a weight might have problems being useful for both the digit 7 and 0 but overtime it can find a feature which is useful for both classes. Thus gradient variance should reduce over time as features become more stable. If we take this perspective then it is important to keep some medium-to-high variance weights. Low variance weights have "settled in" and follow the gradient for a specific set of classes. These weights will not change much anymore while high variance weights might change a lot. So high variance weights might have "potential" while the potential of low variance weights is easily assessed by looking at the magnitude of that weights. Thus we might improve pruning if we look at both the variance of the gradient _and_ the magnitude of weights. You can find these examples in ['mnist_cifar/extensions.py']('sparse_learning/mnist_cifar/extensions.py').
 
 ### Implementation
 
 ```python
 def variance_redistribution(masking, name, weight, mask):
     '''Return the mean variance of existing weights.
 
-    Higher variance means the layer does not have enough
-    capacity to model the inputs with the number of current weights.
-    If weights stabilize this means that some weights might
-    be useless/not needed.
+    Intuition: Higher gradient variance means a layer does not have enough
+    capacity to model the inputs with the current number of weights.
+    Thus we want to add more weights if we have higher variance.
+    If variance of the gradient stabilizes this means
+    that some weights might be useless/not needed.
     '''
-    layer_importance = torch.var(weight.grad[mask.byte()]).mean().item()
+    # Adam calculates the running average of the sum of square for us
+    # This is similar to RMSProp.
+    if 'exp_avg_sq' not in masking.optimizer.state[weight]:
+        print('Variance redistribution requires the adam optimizer to be run!')
+        raise Exception('Variance redistribution requires the adam optimizer to be run!')
+    iv_adam_sumsq = torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])
+
+    layer_importance = iv_adam_sumsq[mask.byte()].mean().item()
     return layer_importance
 
+
 def magnitude_variance_pruning(masking, mask, weight, name):
     ''' Prunes weights which have high gradient variance and low magnitude.
 
@@ -135,13 +144,10 @@ Running 10 additional iterations (add `--iters 10`) of our new method with 5% we
 ```bash
 python get_results_from_logs.py
 
-Test set results for log: ./logs/lenet5_0.05_520776ed.log
-Arguments:
-augment=False, batch_size=100, bench=False, data='mnist', decay_frequency=25000, dense=False, density=0.05, epochs=100, fp16=False, growth='momentum', iters=10, l1=0.0, l2=0.0005, log_interval=100, lr=0.001, model='lenet5', momentum=0.9, no_cuda=False, optimizer='adam', prune='magnitude_variance', prune_rate=0.5, redistribution='variance', resume=None, save_features=False, save_model='./models/model.pt', seed=17, start_epoch=1, test_batch_size=100, valid_split=0.1, verbose=True
+Accuracy. Median: 0.99300, Mean: 0.99300, Standard Error: 0.00019, Sample size: 11, 95% CI: (0.99262,0.99338)
+Error. Median: 0.00700, Mean: 0.00700, Standard Error: 0.00019, Sample size: 11, 95% CI: (0.00662,0.00738)
+Loss. Median: 0.02200, Mean: 0.02175, Standard Error: 0.00027, Sample size: 11, 95% CI: (0.02122,0.02228)
 
-Accuracy. Mean: 0.99349, Standard Error: 0.00013, Sample size: 11, 95% CI: (0.99323,0.99375)
-Error. Mean: 0.00651, Standard Error: 0.00013, Sample size: 11, 95% CI: (0.00625,0.00677)
-Loss. Mean: 0.02078, Standard Error: 0.00035, Sample size: 11, 95% CI: (0.02010,0.02146)
 ```
 
-Sparse momentum achieves an error of 0.0069 for this setting and the lower 95% confidence interval is 0.00649. Thus for this setting our results overlap with the confidence intervals of sparse momentum. Thus our new variance method is _as good or better_ than sparse momentum for this particular problem (Caffe LeNet-5 with 5% weights on MNIST).
+Sparse momentum achieves an error of 0.0069 for this setting and the upper 95% confidence interval is 0.00739. Thus for this setting our results overlap with the confidence intervals of sparse momentum. Thus our new variance method is _as good_ as sparse momentum for this particular problem (Caffe LeNet-5 with 5% weights on MNIST).
````
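The updated paragraph in the diff above compares the new method's reported 95% confidence interval against sparse momentum's point estimate. As an illustrative sanity check (not part of the commit), the sketch below recomputes a normal-approximation interval from the reported mean and standard error and tests whether sparse momentum's error of 0.0069 falls inside it. The 1.96 z-value is an assumption about how `get_results_from_logs.py` builds its intervals, so small rounding differences against the logged (0.00662, 0.00738) are expected.

```python
# Illustrative only: recompute the error CI reported in the diff above and check
# whether sparse momentum's error falls inside it. The 1.96 normal approximation
# is an assumption; the actual script may use a slightly different interval formula.
mean_error = 0.00700            # reported mean test error of the variance method
se_error = 0.00019              # reported standard error (sample size 11)
sparse_momentum_error = 0.0069  # point estimate quoted in the tutorial

lower = mean_error - 1.96 * se_error
upper = mean_error + 1.96 * se_error
print(f'95% CI for the new method: ({lower:.5f}, {upper:.5f})')
print('Overlaps sparse momentum:', lower <= sparse_momentum_error <= upper)
```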

mnist_cifar/extensions.py (+13 -5)

```diff
@@ -48,12 +48,20 @@ def your_redistribution(masking, name, weight, mask):
 def variance_redistribution(masking, name, weight, mask):
     '''Return the mean variance of existing weights.
 
-    Higher variance means the layer does not have enough
-    capacity to model the inputs with the number of current weights.
-    If weights stabilize this means that some weights might
-    be useless/not needed.
+    Higher gradient variance means a layer does not have enough
+    capacity to model the inputs with the current number of weights.
+    Thus we want to add more weights if we have higher variance.
+    If variance of the gradient stabilizes this means
+    that some weights might be useless/not needed.
     '''
-    layer_importance = torch.var(weight.grad[mask.byte()]).mean().item()
+    # Adam calculates the running average of the sum of square for us
+    # This is similar to RMSProp.
+    if 'exp_avg_sq' not in masking.optimizer.state[weight]:
+        print('Variance redistribution requires the adam optimizer to be run!')
+        raise Exception('Variance redistribution requires the adam optimizer to be run!')
+    iv_adam_sumsq = torch.sqrt(masking.optimizer.state[weight]['exp_avg_sq'])
+
+    layer_importance = iv_adam_sumsq[mask.byte()].mean().item()
     return layer_importance
 
 
```
mnist_cifar/main.py (+3 -2)

```diff
@@ -159,6 +159,7 @@ def main():
 
     args = parser.parse_args()
     setup_logger(args)
+    print_and_log(args)
 
     if args.fp16:
         try:
@@ -202,13 +203,13 @@ def main():
 
     # add custom prune/growth/redisribution here
     if args.prune == 'magnitude_variance':
+        print('Using magnitude-variance pruning. Switching to Adam optimizer...')
         args.prune = magnitude_variance_pruning
         args.optimizer = 'adam'
-        args.lr /= 100.0
     if args.redistribution == 'variance':
+        print('Using variance redistribution. Switching to Adam optimizer...')
         args.redistribution = variance_redistribution
         args.optimizer = 'adam'
-        args.lr /= 100.0
 
     optimizer = None
     if args.optimizer == 'sgd':
```
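The main.py hunk above swaps the string values of `args.prune` and `args.redistribution` for the extension functions themselves, so later code can treat built-in mode names and custom callables through one code path. The sketch below is a hypothetical illustration of that pattern, not the library's actual dispatch code; `BUILTIN_PRUNE`, `resolve_prune`, and the stub functions are invented names for demonstration only.

```python
# Hypothetical illustration of the string-or-callable pattern implied by main.py above.
# None of these names exist in the library; they only mirror the idea that a mode can
# be either a registered name or a user-supplied function.
from typing import Callable, Union

def magnitude_prune(masking, mask, weight, name):
    """Stand-in for a built-in pruning mode."""
    return mask

def magnitude_variance_pruning(masking, mask, weight, name):
    """Stand-in for the custom extension wired up in main.py."""
    return mask

BUILTIN_PRUNE = {'magnitude': magnitude_prune}

def resolve_prune(mode: Union[str, Callable]) -> Callable:
    # Strings are looked up in the registry; callables pass through unchanged,
    # which is why main.py can simply assign the function object to args.prune.
    return BUILTIN_PRUNE[mode] if isinstance(mode, str) else mode

assert resolve_prune('magnitude') is magnitude_prune
assert resolve_prune(magnitude_variance_pruning) is magnitude_variance_pruning
```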
