Commit e655ab6

Remove unused channels w/ Keras Surgeon
1 parent a49a113 commit e655ab6

2 files changed: +193 −0 lines changed
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D
from hls4ml.optimization.keras.utils import get_last_layer_with_weights

'''
Function for removing zero neurons & filters from a model and rewiring the model graph.
This function is built on top of Keras Surgeon, available at: https://github.com/BenWhetton/keras-surgeon
Keras Surgeon is no longer under active development and does not work with TensorFlow 2.3+ and QKeras.
The baseline version was forked and updated, available at: https://github.com/bo3z/keras-surgeon

Args:
    - model (keras.model): Input model

Return:
    - reduced (keras.model): Modified model, with redundant structures removed
'''
def reduce_model(model):
    # TODO - Should we make Keras Surgeon a hard requirement in setup.cfg? If so, it needs to be installed from git, @bo3z fork
    try:
        from kerassurgeon import Surgeon
    except ModuleNotFoundError:
        raise Exception('Keras Surgeon not installed. Unable to reduce model footprint. '
                        'Please install an up-to-date Keras Surgeon compatible with TensorFlow 2.3+ and QKeras. '
                        'Installation from git: https://github.com/bo3z/keras-surgeon')

    # Initiate surgeon
    surgeon = Surgeon(model)

    # Iterate through layers and identify neurons (columns) and filters (tensors, W x H x C) to be removed
    last_idx = get_last_layer_with_weights(model)
    for idx, layer in enumerate(model.layers):
        # The last layer with weights cannot be removed, as it maps to the data set labels
        if idx == last_idx:
            break

        # Currently supported: Dense and Conv2D; these two could be combined in a single if-statement
        # Keras Surgeon has a full range of support for Conv1D / Conv3D, recurrent layers etc. - might extend in the future
        if isinstance(layer, Dense):
            weights = layer.get_weights()[0]
            zeros = np.where(~weights.any(axis=0))[0].tolist()
            surgeon.add_job('delete_channels', layer, channels=zeros)

        elif isinstance(layer, Conv2D):
            weights = layer.get_weights()[0]
            zeros = np.where(~weights.reshape(-1, weights.shape[-1]).any(axis=0))[0].tolist()
            surgeon.add_job('delete_channels', layer, channels=zeros)

    # Reduce model
    reduced = surgeon.operate()

    # By default, Keras Surgeon returns a Functional model
    # If the original was a Sequential model, convert back
    is_sequential = model.__class__.__name__ == 'Sequential'
    if is_sequential:
        return Sequential(layers=reduced.layers)
    else:
        return reduced
Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
from qkeras import quantized_bits
from qkeras import QDense, QActivation, QConv2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Softmax, BatchNormalization, ReLU, Flatten, AveragePooling2D
from hls4ml.optimization.keras.reduction import reduce_model
from hls4ml.optimization.keras.utils import get_model_sparsity

'''
Set some neurons / filters to zero and verify that these are removed.
Even if some neurons (columns) in the output layer are zero, they should not be removed (to match data set labels).
The tests verify the above property by setting some zeros in the last layer and checking that these remain in place.
'''

def test_keras_model_reduction():
    model = Sequential()
    model.add(Conv2D(8, (3, 3), input_shape=(64, 64, 1), name='conv2d_1', padding='same'))
    model.add(MaxPooling2D())
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Conv2D(32, (5, 5), padding='same', name='conv2d_2'))
    model.add(AveragePooling2D())
    model.add(BatchNormalization())
    model.add(ReLU())
    model.add(Flatten())
    model.add(Dense(32, input_shape=(16,), name='dense_1', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(14, name='dense_2', activation='relu'))
    model.add(BatchNormalization())
    model.add(Dense(5, name='dense_3'))
    model.add(Softmax())

    indices = {
        'conv2d_1': [2, 4, 7],
        'conv2d_2': [0, 1, 2, 3, 4, 5],
        'dense_1': [0, 5, 17, 28],
        'dense_2': [1, 9, 4],
        'dense_3': [3],
    }
    for layer in model.layers:
        if isinstance(layer, Dense):
            weights = layer.get_weights()
            weights[0][:, indices[layer.name]] = 0
            layer.set_weights(weights)
        if isinstance(layer, Conv2D):
            weights = layer.get_weights()
            weights[0][:, :, :, indices[layer.name]] = 0
            layer.set_weights(weights)

    sparsity, _ = get_model_sparsity(model)
    assert sparsity > 0

    reduced = reduce_model(model)
    assert reduced.get_layer('conv2d_1').get_weights()[0].shape == (3, 3, 1, 5)
    assert reduced.get_layer('conv2d_2').get_weights()[0].shape == (5, 5, 5, 26)
    assert reduced.get_layer('dense_1').get_weights()[0].shape == (6656, 28)
    assert reduced.get_layer('dense_2').get_weights()[0].shape == (28, 11)
    assert reduced.get_layer('dense_3').get_weights()[0].shape == (11, 5)

    _, layer_sparsity = get_model_sparsity(reduced)
    assert layer_sparsity['conv2d_1'] == 0
    assert layer_sparsity['conv2d_2'] == 0
    assert layer_sparsity['dense_1'] == 0
    assert layer_sparsity['dense_2'] == 0
    assert layer_sparsity['dense_3'] > 0

def test_qkeras_model_reduction():
    bits = 8
    activation = 'quantized_relu(4)'
    quantizer = quantized_bits(bits, 0)

    model = Sequential()
    model.add(QConv2D(8, (3, 3), input_shape=(64, 64, 1), name='qconv2d_1', padding='same', kernel_quantizer=quantizer))
    model.add(MaxPooling2D())
    model.add(BatchNormalization())
    model.add(QActivation(activation, name='qrelu_1'))
    model.add(QConv2D(32, (5, 5), padding='same', name='qconv2d_2', kernel_quantizer=quantizer))
    model.add(AveragePooling2D())
    model.add(BatchNormalization())
    model.add(QActivation(activation, name='qrelu_2'))
    model.add(Flatten())
    model.add(QDense(32, input_shape=(16,), name='qdense_1', kernel_quantizer=quantizer))
    model.add(QActivation(activation, name='qrelu_3'))
    model.add(BatchNormalization())
    model.add(QDense(14, name='qdense_2', kernel_quantizer=quantizer))
    model.add(QActivation(activation, name='qrelu_4'))
    model.add(BatchNormalization())
    model.add(QDense(5, name='qdense_3', kernel_quantizer=quantizer))
    model.add(Softmax())

    indices = {
        'qconv2d_1': [2, 4, 7],
        'qconv2d_2': [0, 1, 2, 3, 4, 5],
        'qdense_1': [0, 5, 17, 28],
        'qdense_2': [1, 9, 4],
        'qdense_3': [3],
    }
    for layer in model.layers:
        if isinstance(layer, QDense):
            weights = layer.get_weights()
            weights[0][:, indices[layer.name]] = 0
            layer.set_weights(weights)
        if isinstance(layer, QConv2D):
            weights = layer.get_weights()
            weights[0][:, :, :, indices[layer.name]] = 0
            layer.set_weights(weights)

    sparsity, _ = get_model_sparsity(model)
    assert sparsity > 0

    reduced = reduce_model(model)
    assert reduced.get_layer('qconv2d_1').get_weights()[0].shape == (3, 3, 1, 5)
    assert reduced.get_layer('qconv2d_2').get_weights()[0].shape == (5, 5, 5, 26)
    assert reduced.get_layer('qdense_1').get_weights()[0].shape == (6656, 28)
    assert reduced.get_layer('qdense_2').get_weights()[0].shape == (28, 11)
    assert reduced.get_layer('qdense_3').get_weights()[0].shape == (11, 5)

    _, layer_sparsity = get_model_sparsity(reduced)
    assert layer_sparsity['qconv2d_1'] == 0
    assert layer_sparsity['qconv2d_2'] == 0
    assert layer_sparsity['qdense_1'] == 0
    assert layer_sparsity['qdense_2'] == 0
    assert layer_sparsity['qdense_3'] > 0

    # Verify network surgery has no impact on quantization
    assert isinstance(reduced.get_layer('qrelu_1'), QActivation)
    assert isinstance(reduced.get_layer('qrelu_2'), QActivation)
    assert isinstance(reduced.get_layer('qrelu_3'), QActivation)
    assert isinstance(reduced.get_layer('qrelu_4'), QActivation)
    assert reduced.get_layer('qconv2d_1').kernel_quantizer['config']['bits'] == bits
    assert reduced.get_layer('qconv2d_2').kernel_quantizer['config']['bits'] == bits
    assert reduced.get_layer('qdense_1').kernel_quantizer['config']['bits'] == bits
    assert reduced.get_layer('qdense_2').kernel_quantizer['config']['bits'] == bits
    assert reduced.get_layer('qdense_3').kernel_quantizer['config']['bits'] == bits
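
For context, the following is a minimal end-to-end usage sketch, not part of the commit. It assumes the @bo3z keras-surgeon fork is installed and uses a hypothetical two-layer toy model (the layer names 'hidden' and 'out' are illustrative) whose hidden layer has two neurons zeroed out, as a pruning pass would leave them.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from hls4ml.optimization.keras.reduction import reduce_model

# Hypothetical toy model
model = Sequential([
    Dense(8, input_shape=(4,), name='hidden'),
    Dense(3, name='out'),
])

# Zero the kernel columns of two neurons, mimicking the output of a pruning pass
weights = model.get_layer('hidden').get_weights()
weights[0][:, [1, 6]] = 0
model.get_layer('hidden').set_weights(weights)

# reduce_model deletes the dead neurons and rewires the following layer;
# the output layer is never touched, as it maps to the data set labels
reduced = reduce_model(model)
print(reduced.get_layer('hidden').get_weights()[0].shape)  # expected: (4, 6)
print(reduced.get_layer('out').get_weights()[0].shape)     # expected: (6, 3)
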
