# Transforms under test and the graph-transformation driver.
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_transforms
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
from tensorflow_model_optimization.python.core.quantization.keras.layers import conv_batchnorm_test_utils
from tensorflow_model_optimization.python.core.quantization.keras.layers import dense_batchnorm_test_utils

ModelTransformer = model_transformer.ModelTransformer

# Test-model builders for the layer types exercised by the tests below.
Conv2DModel = conv_batchnorm_test_utils.Conv2DModel
DepthwiseConv2DModel = conv_batchnorm_test_utils.DepthwiseConv2DModel
DenseModel = dense_batchnorm_test_utils.DenseModel

# NOTE(review): `tf` is imported above this visible region of the file.
keras = tf.keras
@@ -73,21 +75,26 @@ def _get_model(
7375 post_bn_activation = activation ,
7476 squeeze_type = squeeze_type ,
7577 normalization_type = normalization_type )
78+ elif layer_type == 'Dense' :
79+ return DenseModel .get_nonfolded_batchnorm_model (
80+ post_bn_activation = activation , normalization_type = normalization_type )
7681
7782 def _get_input_shape (self , layer_type ):
7883 if layer_type == 'Conv2D' :
7984 return Conv2DModel .get_batched_input_shape ()
8085 elif layer_type == 'DepthwiseConv2D' :
8186 return DepthwiseConv2DModel .get_batched_input_shape ()
87+ elif layer_type == 'Dense' :
88+ return DenseModel .get_batched_input_shape ()
8289
83- def _test_conv_squeeze_bn_activation_transform (
90+ def _test_conv_squeeze_or_dense_bn_activation_transform (
8491 self ,
8592 layer_type ,
8693 squeeze_type ,
8794 normalization_type ,
8895 activation_type ,
8996 transform_class ,
90- conv_activation_class ,
97+ conv_or_dense_activation_class ,
9198 normalization_quantize_config_class ):
9299 model = self ._get_model (layer_type ,
93100 squeeze_type ,
@@ -107,7 +114,7 @@ def _test_conv_squeeze_bn_activation_transform(
107114 bn_layer = transformed_model .layers [2 ]
108115
109116 self .assertIsInstance (
110- conv_layer .activation , conv_activation_class )
117+ conv_layer .activation , conv_or_dense_activation_class )
111118 self .assertIsInstance (
112119 updated_metadata .get (bn_layer .name ).get ('quantize_config' ),
113120 normalization_quantize_config_class )
@@ -123,13 +130,13 @@ def _test_conv_squeeze_bn_activation_transform(
123130 ('DepthwiseConv2D' , 'SyncBatchNormalization' ),
124131 )
125132 def testConv2DBatchNormQuantize (self , layer_type , normalization_type ):
126- self ._test_conv_squeeze_bn_activation_transform (
133+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
127134 layer_type = layer_type ,
128135 squeeze_type = None ,
129136 normalization_type = normalization_type ,
130137 activation_type = None ,
131138 transform_class = default_n_bit_transforms .Conv2DBatchNormQuantize ,
132- conv_activation_class = quantize_aware_activation .NoOpActivation ,
139+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
133140 normalization_quantize_config_class =
134141 n_bit_configs .DefaultNBitOutputQuantizeConfig )
135142
@@ -140,14 +147,14 @@ def testConv2DBatchNormQuantize(self, layer_type, normalization_type):
140147 ('DepthwiseConv2D' , 'SyncBatchNormalization' ),
141148 )
142149 def testConv2DBatchNormReLUQuantize (self , layer_type , normalization_type ):
143- self ._test_conv_squeeze_bn_activation_transform (
150+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
144151 layer_type = layer_type ,
145152 squeeze_type = None ,
146153 normalization_type = normalization_type ,
147154 activation_type = 'relu' ,
148155 transform_class =
149156 default_n_bit_transforms .Conv2DBatchNormReLUQuantize ,
150- conv_activation_class = quantize_aware_activation .NoOpActivation ,
157+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
151158 normalization_quantize_config_class =
152159 n_bit_configs .NoOpQuantizeConfig )
153160
@@ -159,14 +166,14 @@ def testConv2DBatchNormReLUQuantize(self, layer_type, normalization_type):
159166 )
160167 def testConv2DBatchNormActivationQuantize (
161168 self , layer_type , normalization_type ):
162- self ._test_conv_squeeze_bn_activation_transform (
169+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
163170 layer_type = layer_type ,
164171 squeeze_type = None ,
165172 normalization_type = normalization_type ,
166173 activation_type = 'act_relu' ,
167174 transform_class =
168175 default_n_bit_transforms .Conv2DBatchNormActivationQuantize ,
169- conv_activation_class = quantize_aware_activation .NoOpActivation ,
176+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
170177 normalization_quantize_config_class =
171178 n_bit_configs .NoOpQuantizeConfig )
172179
@@ -178,14 +185,14 @@ def testConv2DBatchNormActivationQuantize(
178185 )
179186 def testConv2DReshapeBatchNormQuantize (
180187 self , layer_type , normalization_type ):
181- self ._test_conv_squeeze_bn_activation_transform (
188+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
182189 layer_type = layer_type ,
183190 squeeze_type = 'sepconv1d_squeeze' ,
184191 normalization_type = normalization_type ,
185192 activation_type = False ,
186193 transform_class =
187194 default_n_bit_transforms .Conv2DReshapeBatchNormQuantize ,
188- conv_activation_class = quantize_aware_activation .NoOpActivation ,
195+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
189196 normalization_quantize_config_class =
190197 n_bit_configs .DefaultNBitOutputQuantizeConfig )
191198
@@ -197,14 +204,14 @@ def testConv2DReshapeBatchNormQuantize(
197204 )
198205 def testConv2DReshapeBatchNormReLUQuantize (
199206 self , layer_type , normalization_type ):
200- self ._test_conv_squeeze_bn_activation_transform (
207+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
201208 layer_type = layer_type ,
202209 squeeze_type = 'sepconv1d_squeeze' ,
203210 normalization_type = normalization_type ,
204211 activation_type = 'relu' ,
205212 transform_class =
206213 default_n_bit_transforms .Conv2DReshapeBatchNormReLUQuantize ,
207- conv_activation_class = quantize_aware_activation .NoOpActivation ,
214+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
208215 normalization_quantize_config_class =
209216 n_bit_configs .NoOpQuantizeConfig )
210217
@@ -216,17 +223,64 @@ def testConv2DReshapeBatchNormReLUQuantize(
216223 )
217224 def testConv2DReshapeBatchNormActivationQuantize (
218225 self , layer_type , normalization_type ):
219- self ._test_conv_squeeze_bn_activation_transform (
226+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
220227 layer_type = layer_type ,
221228 squeeze_type = 'sepconv1d_squeeze' ,
222229 normalization_type = normalization_type ,
223230 activation_type = 'act_relu' ,
224231 transform_class =
225232 default_n_bit_transforms .Conv2DReshapeBatchNormActivationQuantize ,
226- conv_activation_class = quantize_aware_activation .NoOpActivation ,
233+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
227234 normalization_quantize_config_class =
228235 n_bit_configs .NoOpQuantizeConfig )
229236
237+ @parameterized .parameters (
238+ ('Dense' , 'BatchNormalization' ),
239+ ('Dense' , 'SyncBatchNormalization' ),
240+ )
241+ def testDenseBatchNormQuantize (self , layer_type , normalization_type ):
242+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
243+ layer_type = layer_type ,
244+ squeeze_type = None ,
245+ normalization_type = normalization_type ,
246+ activation_type = None ,
247+ transform_class = default_n_bit_transforms .DenseBatchNormQuantize ,
248+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
249+ normalization_quantize_config_class = n_bit_configs
250+ .DefaultNBitOutputQuantizeConfig )
251+
252+ @parameterized .parameters (
253+ ('Dense' , 'BatchNormalization' ),
254+ ('Dense' , 'SyncBatchNormalization' ),
255+ )
256+ def testDenseBatchNormReLUQuantize (self , layer_type , normalization_type ):
257+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
258+ layer_type = layer_type ,
259+ squeeze_type = None ,
260+ normalization_type = normalization_type ,
261+ activation_type = 'relu' ,
262+ transform_class = default_n_bit_transforms .DenseBatchNormReLUQuantize ,
263+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
264+ normalization_quantize_config_class = n_bit_configs
265+ .NoOpQuantizeConfig )
266+
267+ @parameterized .parameters (
268+ ('Dense' , 'BatchNormalization' ),
269+ ('Dense' , 'SyncBatchNormalization' ),
270+ )
271+ def testDenseBatchNormActivationQuantize (self , layer_type ,
272+ normalization_type ):
273+ self ._test_conv_squeeze_or_dense_bn_activation_transform (
274+ layer_type = layer_type ,
275+ squeeze_type = None ,
276+ normalization_type = normalization_type ,
277+ activation_type = 'act_relu' ,
278+ transform_class = default_n_bit_transforms
279+ .DenseBatchNormActivationQuantize ,
280+ conv_or_dense_activation_class = quantize_aware_activation .NoOpActivation ,
281+ normalization_quantize_config_class = n_bit_configs
282+ .NoOpQuantizeConfig )
283+
230284 @parameterized .named_parameters (
231285 ('padding_valid' , {'padding' : 'valid' }),
232286 ('padding_same' , {'padding' : 'same' }),
0 commit comments