Skip to content

Commit 4bede42

Browse files
authored
run black normalizations and wrappers (tensorflow#1061)
* black layers fun * modify pyproject
1 parent ecba851 commit 4bede42

File tree

5 files changed

+176
-170
lines changed

5 files changed

+176
-170
lines changed

pyproject.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ exclude = '''
1414
| build
1515
| dist
1616
)/
17-
| tensorflow_addons/layers/normalizations.py
18-
| tensorflow_addons/layers/normalizations_test.py
19-
| tensorflow_addons/layers/wrappers.py
20-
| tensorflow_addons/layers/wrappers_test.py
2117
| tensorflow_addons/losses/__init__.py
2218
| tensorflow_addons/losses/focal_loss.py
2319
| tensorflow_addons/losses/giou_loss.py

tensorflow_addons/layers/normalizations.py

Lines changed: 60 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tensorflow_addons.utils import types
2323

2424

25-
@tf.keras.utils.register_keras_serializable(package='Addons')
25+
@tf.keras.utils.register_keras_serializable(package="Addons")
2626
class GroupNormalization(tf.keras.layers.Layer):
2727
"""Group normalization layer.
2828
@@ -71,19 +71,21 @@ class GroupNormalization(tf.keras.layers.Layer):
7171
"""
7272

7373
@typechecked
74-
def __init__(self,
75-
groups: int = 2,
76-
axis: int = -1,
77-
epsilon: int = 1e-3,
78-
center: bool = True,
79-
scale: bool = True,
80-
beta_initializer: types.Initializer = 'zeros',
81-
gamma_initializer: types.Initializer = 'ones',
82-
beta_regularizer: types.Regularizer = None,
83-
gamma_regularizer: types.Regularizer = None,
84-
beta_constraint: types.Constraint = None,
85-
gamma_constraint: types.Constraint = None,
86-
**kwargs):
74+
def __init__(
75+
self,
76+
groups: int = 2,
77+
axis: int = -1,
78+
epsilon: int = 1e-3,
79+
center: bool = True,
80+
scale: bool = True,
81+
beta_initializer: types.Initializer = "zeros",
82+
gamma_initializer: types.Initializer = "ones",
83+
beta_regularizer: types.Regularizer = None,
84+
gamma_regularizer: types.Regularizer = None,
85+
beta_constraint: types.Constraint = None,
86+
gamma_constraint: types.Constraint = None,
87+
**kwargs
88+
):
8789
super().__init__(**kwargs)
8890
self.supports_masking = True
8991
self.groups = groups
@@ -117,39 +119,32 @@ def call(self, inputs):
117119
tensor_input_shape = tf.shape(inputs)
118120

119121
reshaped_inputs, group_shape = self._reshape_into_groups(
120-
inputs, input_shape, tensor_input_shape)
122+
inputs, input_shape, tensor_input_shape
123+
)
121124

122-
normalized_inputs = self._apply_normalization(reshaped_inputs,
123-
input_shape)
125+
normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
124126

125127
outputs = tf.reshape(normalized_inputs, tensor_input_shape)
126128

127129
return outputs
128130

129131
def get_config(self):
130132
config = {
131-
'groups':
132-
self.groups,
133-
'axis':
134-
self.axis,
135-
'epsilon':
136-
self.epsilon,
137-
'center':
138-
self.center,
139-
'scale':
140-
self.scale,
141-
'beta_initializer':
142-
tf.keras.initializers.serialize(self.beta_initializer),
143-
'gamma_initializer':
144-
tf.keras.initializers.serialize(self.gamma_initializer),
145-
'beta_regularizer':
146-
tf.keras.regularizers.serialize(self.beta_regularizer),
147-
'gamma_regularizer':
148-
tf.keras.regularizers.serialize(self.gamma_regularizer),
149-
'beta_constraint':
150-
tf.keras.constraints.serialize(self.beta_constraint),
151-
'gamma_constraint':
152-
tf.keras.constraints.serialize(self.gamma_constraint)
133+
"groups": self.groups,
134+
"axis": self.axis,
135+
"epsilon": self.epsilon,
136+
"center": self.center,
137+
"scale": self.scale,
138+
"beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
139+
"gamma_initializer": tf.keras.initializers.serialize(
140+
self.gamma_initializer
141+
),
142+
"beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
143+
"gamma_regularizer": tf.keras.regularizers.serialize(
144+
self.gamma_regularizer
145+
),
146+
"beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
147+
"gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
153148
}
154149
base_config = super().get_config()
155150
return {**base_config, **config}
@@ -174,7 +169,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
174169
group_reduction_axes.pop(axis)
175170

176171
mean, variance = tf.nn.moments(
177-
reshaped_inputs, group_reduction_axes, keepdims=True)
172+
reshaped_inputs, group_reduction_axes, keepdims=True
173+
)
178174

179175
gamma, beta = self._get_reshaped_weights(input_shape)
180176
normalized_inputs = tf.nn.batch_normalization(
@@ -183,7 +179,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
183179
variance=variance,
184180
scale=gamma,
185181
offset=beta,
186-
variance_epsilon=self.epsilon)
182+
variance_epsilon=self.epsilon,
183+
)
187184
return normalized_inputs
188185

189186
def _get_reshaped_weights(self, input_shape):
@@ -200,10 +197,11 @@ def _get_reshaped_weights(self, input_shape):
200197
def _check_if_input_shape_is_none(self, input_shape):
201198
dim = input_shape[self.axis]
202199
if dim is None:
203-
raise ValueError('Axis ' + str(self.axis) + ' of '
204-
'input tensor should have a defined dimension '
205-
'but the layer received an input with shape ' +
206-
str(input_shape) + '.')
200+
raise ValueError(
201+
"Axis " + str(self.axis) + " of "
202+
"input tensor should have a defined dimension "
203+
"but the layer received an input with shape " + str(input_shape) + "."
204+
)
207205

208206
def _set_number_of_groups_for_instance_norm(self, input_shape):
209207
dim = input_shape[self.axis]
@@ -216,26 +214,30 @@ def _check_size_of_dimensions(self, input_shape):
216214
dim = input_shape[self.axis]
217215
if dim < self.groups:
218216
raise ValueError(
219-
'Number of groups (' + str(self.groups) + ') cannot be '
220-
'more than the number of channels (' + str(dim) + ').')
217+
"Number of groups (" + str(self.groups) + ") cannot be "
218+
"more than the number of channels (" + str(dim) + ")."
219+
)
221220

222221
if dim % self.groups != 0:
223222
raise ValueError(
224-
'Number of groups (' + str(self.groups) + ') must be a '
225-
'multiple of the number of channels (' + str(dim) + ').')
223+
"Number of groups (" + str(self.groups) + ") must be a "
224+
"multiple of the number of channels (" + str(dim) + ")."
225+
)
226226

227227
def _check_axis(self):
228228

229229
if self.axis == 0:
230230
raise ValueError(
231231
"You are trying to normalize your batch axis. Do you want to "
232-
"use tf.layer.batch_normalization instead")
232+
"use tf.layer.batch_normalization instead"
233+
)
233234

234235
def _create_input_spec(self, input_shape):
235236

236237
dim = input_shape[self.axis]
237238
self.input_spec = tf.keras.layers.InputSpec(
238-
ndim=len(input_shape), axes={self.axis: dim})
239+
ndim=len(input_shape), axes={self.axis: dim}
240+
)
239241

240242
def _add_gamma_weight(self, input_shape):
241243

@@ -245,10 +247,11 @@ def _add_gamma_weight(self, input_shape):
245247
if self.scale:
246248
self.gamma = self.add_weight(
247249
shape=shape,
248-
name='gamma',
250+
name="gamma",
249251
initializer=self.gamma_initializer,
250252
regularizer=self.gamma_regularizer,
251-
constraint=self.gamma_constraint)
253+
constraint=self.gamma_constraint,
254+
)
252255
else:
253256
self.gamma = None
254257

@@ -260,10 +263,11 @@ def _add_beta_weight(self, input_shape):
260263
if self.center:
261264
self.beta = self.add_weight(
262265
shape=shape,
263-
name='beta',
266+
name="beta",
264267
initializer=self.beta_initializer,
265268
regularizer=self.beta_regularizer,
266-
constraint=self.beta_constraint)
269+
constraint=self.beta_constraint,
270+
)
267271
else:
268272
self.beta = None
269273

@@ -274,7 +278,7 @@ def _create_broadcast_shape(self, input_shape):
274278
return broadcast_shape
275279

276280

277-
@tf.keras.utils.register_keras_serializable(package='Addons')
281+
@tf.keras.utils.register_keras_serializable(package="Addons")
278282
class InstanceNormalization(GroupNormalization):
279283
"""Instance normalization layer.
280284

0 commit comments

Comments
 (0)