Skip to content

run black normalizations and wrappers #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ exclude = '''
| build
| dist
)/
| tensorflow_addons/layers/normalizations.py
| tensorflow_addons/layers/normalizations_test.py
| tensorflow_addons/layers/wrappers.py
| tensorflow_addons/layers/wrappers_test.py
| tensorflow_addons/losses/__init__.py
| tensorflow_addons/losses/focal_loss.py
| tensorflow_addons/losses/giou_loss.py
Expand Down
116 changes: 60 additions & 56 deletions tensorflow_addons/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.keras.utils.register_keras_serializable(package="Addons")
class GroupNormalization(tf.keras.layers.Layer):
"""Group normalization layer.

Expand Down Expand Up @@ -71,19 +71,21 @@ class GroupNormalization(tf.keras.layers.Layer):
"""

@typechecked
def __init__(self,
groups: int = 2,
axis: int = -1,
epsilon: int = 1e-3,
center: bool = True,
scale: bool = True,
beta_initializer: types.Initializer = 'zeros',
gamma_initializer: types.Initializer = 'ones',
beta_regularizer: types.Regularizer = None,
gamma_regularizer: types.Regularizer = None,
beta_constraint: types.Constraint = None,
gamma_constraint: types.Constraint = None,
**kwargs):
def __init__(
self,
groups: int = 2,
axis: int = -1,
epsilon: int = 1e-3,
center: bool = True,
scale: bool = True,
beta_initializer: types.Initializer = "zeros",
gamma_initializer: types.Initializer = "ones",
beta_regularizer: types.Regularizer = None,
gamma_regularizer: types.Regularizer = None,
beta_constraint: types.Constraint = None,
gamma_constraint: types.Constraint = None,
**kwargs
):
super().__init__(**kwargs)
self.supports_masking = True
self.groups = groups
Expand Down Expand Up @@ -117,39 +119,32 @@ def call(self, inputs):
tensor_input_shape = tf.shape(inputs)

reshaped_inputs, group_shape = self._reshape_into_groups(
inputs, input_shape, tensor_input_shape)
inputs, input_shape, tensor_input_shape
)

normalized_inputs = self._apply_normalization(reshaped_inputs,
input_shape)
normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)

outputs = tf.reshape(normalized_inputs, tensor_input_shape)

return outputs

def get_config(self):
config = {
'groups':
self.groups,
'axis':
self.axis,
'epsilon':
self.epsilon,
'center':
self.center,
'scale':
self.scale,
'beta_initializer':
tf.keras.initializers.serialize(self.beta_initializer),
'gamma_initializer':
tf.keras.initializers.serialize(self.gamma_initializer),
'beta_regularizer':
tf.keras.regularizers.serialize(self.beta_regularizer),
'gamma_regularizer':
tf.keras.regularizers.serialize(self.gamma_regularizer),
'beta_constraint':
tf.keras.constraints.serialize(self.beta_constraint),
'gamma_constraint':
tf.keras.constraints.serialize(self.gamma_constraint)
"groups": self.groups,
"axis": self.axis,
"epsilon": self.epsilon,
"center": self.center,
"scale": self.scale,
"beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
"gamma_initializer": tf.keras.initializers.serialize(
self.gamma_initializer
),
"beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
"gamma_regularizer": tf.keras.regularizers.serialize(
self.gamma_regularizer
),
"beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
"gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
}
base_config = super().get_config()
return {**base_config, **config}
Expand All @@ -174,7 +169,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
group_reduction_axes.pop(axis)

mean, variance = tf.nn.moments(
reshaped_inputs, group_reduction_axes, keepdims=True)
reshaped_inputs, group_reduction_axes, keepdims=True
)

gamma, beta = self._get_reshaped_weights(input_shape)
normalized_inputs = tf.nn.batch_normalization(
Expand All @@ -183,7 +179,8 @@ def _apply_normalization(self, reshaped_inputs, input_shape):
variance=variance,
scale=gamma,
offset=beta,
variance_epsilon=self.epsilon)
variance_epsilon=self.epsilon,
)
return normalized_inputs

def _get_reshaped_weights(self, input_shape):
Expand All @@ -200,10 +197,11 @@ def _get_reshaped_weights(self, input_shape):
def _check_if_input_shape_is_none(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
raise ValueError('Axis ' + str(self.axis) + ' of '
'input tensor should have a defined dimension '
'but the layer received an input with shape ' +
str(input_shape) + '.')
raise ValueError(
"Axis " + str(self.axis) + " of "
"input tensor should have a defined dimension "
"but the layer received an input with shape " + str(input_shape) + "."
)

def _set_number_of_groups_for_instance_norm(self, input_shape):
dim = input_shape[self.axis]
Expand All @@ -216,26 +214,30 @@ def _check_size_of_dimensions(self, input_shape):
dim = input_shape[self.axis]
if dim < self.groups:
raise ValueError(
'Number of groups (' + str(self.groups) + ') cannot be '
'more than the number of channels (' + str(dim) + ').')
"Number of groups (" + str(self.groups) + ") cannot be "
"more than the number of channels (" + str(dim) + ")."
)

if dim % self.groups != 0:
raise ValueError(
'Number of groups (' + str(self.groups) + ') must be a '
'multiple of the number of channels (' + str(dim) + ').')
"Number of groups (" + str(self.groups) + ") must be a "
"multiple of the number of channels (" + str(dim) + ")."
)

def _check_axis(self):

if self.axis == 0:
raise ValueError(
"You are trying to normalize your batch axis. Do you want to "
"use tf.layer.batch_normalization instead")
"use tf.layer.batch_normalization instead"
)

def _create_input_spec(self, input_shape):

dim = input_shape[self.axis]
self.input_spec = tf.keras.layers.InputSpec(
ndim=len(input_shape), axes={self.axis: dim})
ndim=len(input_shape), axes={self.axis: dim}
)

def _add_gamma_weight(self, input_shape):

Expand All @@ -245,10 +247,11 @@ def _add_gamma_weight(self, input_shape):
if self.scale:
self.gamma = self.add_weight(
shape=shape,
name='gamma',
name="gamma",
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint)
constraint=self.gamma_constraint,
)
else:
self.gamma = None

Expand All @@ -260,10 +263,11 @@ def _add_beta_weight(self, input_shape):
if self.center:
self.beta = self.add_weight(
shape=shape,
name='beta',
name="beta",
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint)
constraint=self.beta_constraint,
)
else:
self.beta = None

Expand All @@ -274,7 +278,7 @@ def _create_broadcast_shape(self, input_shape):
return broadcast_shape


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.keras.utils.register_keras_serializable(package="Addons")
class InstanceNormalization(GroupNormalization):
"""Instance normalization layer.

Expand Down
Loading