22
22
from tensorflow_addons .utils import types
23
23
24
24
25
- @tf .keras .utils .register_keras_serializable (package = ' Addons' )
25
+ @tf .keras .utils .register_keras_serializable (package = " Addons" )
26
26
class GroupNormalization (tf .keras .layers .Layer ):
27
27
"""Group normalization layer.
28
28
@@ -71,19 +71,21 @@ class GroupNormalization(tf.keras.layers.Layer):
71
71
"""
72
72
73
73
@typechecked
74
- def __init__ (self ,
75
- groups : int = 2 ,
76
- axis : int = - 1 ,
77
- epsilon : int = 1e-3 ,
78
- center : bool = True ,
79
- scale : bool = True ,
80
- beta_initializer : types .Initializer = 'zeros' ,
81
- gamma_initializer : types .Initializer = 'ones' ,
82
- beta_regularizer : types .Regularizer = None ,
83
- gamma_regularizer : types .Regularizer = None ,
84
- beta_constraint : types .Constraint = None ,
85
- gamma_constraint : types .Constraint = None ,
86
- ** kwargs ):
74
+ def __init__ (
75
+ self ,
76
+ groups : int = 2 ,
77
+ axis : int = - 1 ,
78
+ epsilon : int = 1e-3 ,
79
+ center : bool = True ,
80
+ scale : bool = True ,
81
+ beta_initializer : types .Initializer = "zeros" ,
82
+ gamma_initializer : types .Initializer = "ones" ,
83
+ beta_regularizer : types .Regularizer = None ,
84
+ gamma_regularizer : types .Regularizer = None ,
85
+ beta_constraint : types .Constraint = None ,
86
+ gamma_constraint : types .Constraint = None ,
87
+ ** kwargs
88
+ ):
87
89
super ().__init__ (** kwargs )
88
90
self .supports_masking = True
89
91
self .groups = groups
@@ -117,39 +119,32 @@ def call(self, inputs):
117
119
tensor_input_shape = tf .shape (inputs )
118
120
119
121
reshaped_inputs , group_shape = self ._reshape_into_groups (
120
- inputs , input_shape , tensor_input_shape )
122
+ inputs , input_shape , tensor_input_shape
123
+ )
121
124
122
- normalized_inputs = self ._apply_normalization (reshaped_inputs ,
123
- input_shape )
125
+ normalized_inputs = self ._apply_normalization (reshaped_inputs , input_shape )
124
126
125
127
outputs = tf .reshape (normalized_inputs , tensor_input_shape )
126
128
127
129
return outputs
128
130
129
131
def get_config(self):
    """Return this layer's configuration for serialization.

    Merges the group-normalization hyperparameters with the base
    ``Layer`` config so the layer can be rebuilt via ``from_config``.
    Initializers, regularizers and constraints are converted to their
    serializable (dict/string) forms.
    """
    serialized = {
        "groups": self.groups,
        "axis": self.axis,
        "epsilon": self.epsilon,
        "center": self.center,
        "scale": self.scale,
        "beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
        "gamma_initializer": tf.keras.initializers.serialize(
            self.gamma_initializer
        ),
        "beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
        "gamma_regularizer": tf.keras.regularizers.serialize(
            self.gamma_regularizer
        ),
        "beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
        "gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
    }
    base = super().get_config()
    return {**base, **serialized}
@@ -174,7 +169,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
174
169
group_reduction_axes .pop (axis )
175
170
176
171
mean , variance = tf .nn .moments (
177
- reshaped_inputs , group_reduction_axes , keepdims = True )
172
+ reshaped_inputs , group_reduction_axes , keepdims = True
173
+ )
178
174
179
175
gamma , beta = self ._get_reshaped_weights (input_shape )
180
176
normalized_inputs = tf .nn .batch_normalization (
@@ -183,7 +179,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
183
179
variance = variance ,
184
180
scale = gamma ,
185
181
offset = beta ,
186
- variance_epsilon = self .epsilon )
182
+ variance_epsilon = self .epsilon ,
183
+ )
187
184
return normalized_inputs
188
185
189
186
def _get_reshaped_weights (self , input_shape ):
@@ -200,10 +197,11 @@ def _get_reshaped_weights(self, input_shape):
200
197
def _check_if_input_shape_is_none(self, input_shape):
    """Raise if the size of the normalization axis is statically unknown.

    Grouping the channel axis requires a concrete integer dimension at
    build time, so ``None`` along ``self.axis`` is rejected.
    """
    if input_shape[self.axis] is None:
        raise ValueError(
            "Axis " + str(self.axis) + " of "
            "input tensor should have a defined dimension "
            "but the layer received an input with shape " + str(input_shape) + "."
        )
207
205
208
206
def _set_number_of_groups_for_instance_norm (self , input_shape ):
209
207
dim = input_shape [self .axis ]
def _check_size_of_dimensions(self, input_shape):
    """Validate that the channel axis can be split into ``self.groups``.

    Raises:
        ValueError: If there are fewer channels than groups, or if the
            channel count is not evenly divisible by the group count.
    """
    dim = input_shape[self.axis]
    if dim < self.groups:
        raise ValueError(
            "Number of groups (" + str(self.groups) + ") cannot be "
            "more than the number of channels (" + str(dim) + ")."
        )

    # BUGFIX: the condition tests ``dim % self.groups``, i.e. the number
    # of CHANNELS must be a multiple of the number of GROUPS.  The
    # original message stated the relationship backwards ("groups must
    # be a multiple of channels").
    if dim % self.groups != 0:
        raise ValueError(
            "Number of channels (" + str(dim) + ") must be a "
            "multiple of the number of groups (" + str(self.groups) + ")."
        )
226
226
227
227
def _check_axis(self):
    """Reject normalization over the batch axis (axis 0).

    Normalizing across the batch is batch normalization, not group
    normalization, so axis 0 is disallowed.

    Raises:
        ValueError: If ``self.axis`` is 0.
    """
    if self.axis == 0:
        # BUGFIX: the original message pointed users at
        # ``tf.layer.batch_normalization``, which does not exist; the
        # correct public API is ``tf.keras.layers.BatchNormalization``.
        raise ValueError(
            "You are trying to normalize your batch axis. Do you want to "
            "use tf.keras.layers.BatchNormalization instead"
        )
233
234
234
235
def _create_input_spec(self, input_shape):
    """Constrain future inputs to the observed rank and channel size.

    Records an ``InputSpec`` so subsequent calls must match both the
    number of dimensions seen at build time and the size of the
    normalization axis.
    """
    channel_dim = input_shape[self.axis]
    spec = tf.keras.layers.InputSpec(
        ndim=len(input_shape), axes={self.axis: channel_dim}
    )
    self.input_spec = spec
239
241
240
242
def _add_gamma_weight (self , input_shape ):
241
243
@@ -245,10 +247,11 @@ def _add_gamma_weight(self, input_shape):
245
247
if self .scale :
246
248
self .gamma = self .add_weight (
247
249
shape = shape ,
248
- name = ' gamma' ,
250
+ name = " gamma" ,
249
251
initializer = self .gamma_initializer ,
250
252
regularizer = self .gamma_regularizer ,
251
- constraint = self .gamma_constraint )
253
+ constraint = self .gamma_constraint ,
254
+ )
252
255
else :
253
256
self .gamma = None
254
257
@@ -260,10 +263,11 @@ def _add_beta_weight(self, input_shape):
260
263
if self .center :
261
264
self .beta = self .add_weight (
262
265
shape = shape ,
263
- name = ' beta' ,
266
+ name = " beta" ,
264
267
initializer = self .beta_initializer ,
265
268
regularizer = self .beta_regularizer ,
266
- constraint = self .beta_constraint )
269
+ constraint = self .beta_constraint ,
270
+ )
267
271
else :
268
272
self .beta = None
269
273
@@ -274,7 +278,7 @@ def _create_broadcast_shape(self, input_shape):
274
278
return broadcast_shape
275
279
276
280
277
- @tf .keras .utils .register_keras_serializable (package = ' Addons' )
281
+ @tf .keras .utils .register_keras_serializable (package = " Addons" )
278
282
class InstanceNormalization (GroupNormalization ):
279
283
"""Instance normalization layer.
280
284
0 commit comments