
add filter response normalization #765


Merged
19 commits merged on Mar 29, 2020
Commits (19)
fdc77ab
add filter response normalization
AakashKumarNain Dec 13, 2019
542769c
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Dec 15, 2019
44b757b
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Dec 23, 2019
28d70bf
Merge branch 'master' into AakashKumarNain_add_frn_norm
gabrieldemarmiesse Feb 26, 2020
8887639
Merge branch 'master' into AakashKumarNain_add_frn_norm
gabrieldemarmiesse Feb 26, 2020
9c0d54a
Merge branch 'add_frn_norm' of https://github.com/AakashKumarNain/add…
AakashKumarNain Mar 9, 2020
bb00dbe
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Mar 10, 2020
3030d95
update FRN layer, tests still failing
AakashKumarNain Mar 10, 2020
7b60d00
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Mar 10, 2020
d278cfa
update test cases and set seed
AakashKumarNain Mar 10, 2020
567c2c2
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Mar 23, 2020
969c3a7
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Mar 25, 2020
61fb2da
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Mar 25, 2020
f907777
refactor code
AakashKumarNain Mar 25, 2020
f14511f
add serialization test
AakashKumarNain Mar 25, 2020
2e45832
bug fix in serialization
AakashKumarNain Mar 25, 2020
932656f
move epsilon weights to constructor
AakashKumarNain Mar 25, 2020
61fbe3d
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Mar 26, 2020
618c83f
remove extra checks, add TODO and add grads check for epsilon
AakashKumarNain Mar 26, 2020
1 change: 1 addition & 0 deletions tensorflow_addons/layers/__init__.py
@@ -17,6 +17,7 @@
from tensorflow_addons.layers.gelu import GELU
from tensorflow_addons.layers.maxout import Maxout
from tensorflow_addons.layers.multihead_attention import MultiHeadAttention
from tensorflow_addons.layers.normalizations import FilterResponseNormalization
from tensorflow_addons.layers.normalizations import GroupNormalization
from tensorflow_addons.layers.normalizations import InstanceNormalization
from tensorflow_addons.layers.optical_flow import CorrelationCost
201 changes: 201 additions & 0 deletions tensorflow_addons/layers/normalizations.py
@@ -321,3 +321,204 @@ def __init__(self, **kwargs):

kwargs["groups"] = -1
super().__init__(**kwargs)


@tf.keras.utils.register_keras_serializable(package="Addons")
class FilterResponseNormalization(tf.keras.layers.Layer):
"""Filter response normalization layer.

Filter Response Normalization (FRN) is a normalization method that
enables models trained with per-channel normalization to achieve high
accuracy. It performs better than all other normalization techniques for
small batch sizes and is on par with Batch Normalization for larger
batch sizes.

Arguments
axis: List of axes that should be normalized. This should represent the
spatial dimensions.
epsilon: Small positive float value added to variance to avoid dividing by zero.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
beta_constraint: Optional constraint for the beta weight.
gamma_constraint: Optional constraint for the gamma weight.
learned_epsilon: (bool) Whether to add another learnable
epsilon parameter or not.
name: Optional name for the layer.

Input shape
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model. As of now, this layer
only works on 4-D tensors with shape [N, H, W, C].

TODO: Add support for NCHW data format and FC layers.

Output shape
Same shape as input.

References
- [Filter Response Normalization Layer: Eliminating Batch Dependence
in the training of Deep Neural Networks]
(https://arxiv.org/abs/1911.09737)
"""

def __init__(
self,
epsilon: float = 1e-6,
axis: list = [1, 2],
beta_initializer: types.Initializer = "zeros",
gamma_initializer: types.Initializer = "ones",
beta_regularizer: types.Regularizer = None,
gamma_regularizer: types.Regularizer = None,
beta_constraint: types.Constraint = None,
gamma_constraint: types.Constraint = None,
learned_epsilon: bool = False,
learned_epsilon_constraint: types.Constraint = None,
name: str = None,
**kwargs
):
super().__init__(name=name, **kwargs)
self.epsilon = tf.math.abs(tf.cast(epsilon, dtype=self.dtype))
self.beta_initializer = tf.keras.initializers.get(beta_initializer)
self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
self.beta_regularizer = tf.keras.regularizers.get(beta_regularizer)
self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
self.beta_constraint = tf.keras.constraints.get(beta_constraint)
self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)
self.use_eps_learned = learned_epsilon
self.supports_masking = True

if self.use_eps_learned:
self.eps_learned_initializer = tf.keras.initializers.Constant(1e-4)
self.eps_learned_constraint = tf.keras.constraints.get(
learned_epsilon_constraint
)
self.eps_learned = self.add_weight(
shape=(1,),
name="learned_epsilon",
dtype=self.dtype,
initializer=tf.keras.initializers.get(self.eps_learned_initializer),
regularizer=None,
constraint=self.eps_learned_constraint,
)
else:
self.eps_learned_initializer = None
self.eps_learned_constraint = None

self._check_axis(axis)

def build(self, input_shape):
if len(tf.TensorShape(input_shape)) != 4:
raise ValueError(
"""Only 4-D tensors (CNNs) are supported
as of now."""
)
self._check_if_input_shape_is_none(input_shape)
self._create_input_spec(input_shape)
self._add_gamma_weight(input_shape)
self._add_beta_weight(input_shape)
super().build(input_shape)

def call(self, inputs):
epsilon = self.epsilon
if self.use_eps_learned:
epsilon += tf.math.abs(self.eps_learned)
nu2 = tf.reduce_mean(tf.square(inputs), axis=self.axis, keepdims=True)
normalized_inputs = inputs * tf.math.rsqrt(nu2 + epsilon)
return self.gamma * normalized_inputs + self.beta

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
config = {
"axis": self.axis,
"epsilon": self.epsilon,
"learned_epsilon": self.use_eps_learned,
"beta_initializer": tf.keras.initializers.serialize(self.beta_initializer),
"gamma_initializer": tf.keras.initializers.serialize(
self.gamma_initializer
),
"beta_regularizer": tf.keras.regularizers.serialize(self.beta_regularizer),
"gamma_regularizer": tf.keras.regularizers.serialize(
self.gamma_regularizer
),
"beta_constraint": tf.keras.constraints.serialize(self.beta_constraint),
"gamma_constraint": tf.keras.constraints.serialize(self.gamma_constraint),
"learned_epsilon_constraint": tf.keras.constraints.serialize(
self.eps_learned_constraint
),
}
base_config = super().get_config()
return dict(**base_config, **config)

def _create_input_spec(self, input_shape):
ndims = len(tf.TensorShape(input_shape))
for idx, x in enumerate(self.axis):
if x < 0:
self.axis[idx] = ndims + x

# Validate axes
for x in self.axis:
if x < 0 or x >= ndims:
raise ValueError("Invalid axis: %d" % x)

if len(self.axis) != len(set(self.axis)):
raise ValueError("Duplicate axis: %s" % self.axis)

axis_to_dim = {x: input_shape[x] for x in self.axis}
self.input_spec = tf.keras.layers.InputSpec(ndim=ndims, axes=axis_to_dim)

def _check_axis(self, axis):
if not isinstance(axis, list):
raise TypeError(
"""Expected a list of values but got {}.""".format(type(axis))
)
else:
self.axis = axis

if self.axis != [1, 2]:
raise ValueError(
"""FilterResponseNormalization operates on per-channel basis.
Axis values should be a list of spatial dimensions."""
)

def _check_if_input_shape_is_none(self, input_shape):
dim1, dim2 = input_shape[self.axis[0]], input_shape[self.axis[1]]
if dim1 is None or dim2 is None:
raise ValueError(
"""Axis {} of input tensor should have a defined dimension but
the layer received an input with shape {}.""".format(
self.axis, input_shape
)
)

def _add_gamma_weight(self, input_shape):
# Get the channel dimension
dim = input_shape[-1]
shape = [1, 1, 1, dim]
# Initialize gamma with shape (1, 1, 1, C)
self.gamma = self.add_weight(
shape=shape,
name="gamma",
dtype=self.dtype,
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint,
)

def _add_beta_weight(self, input_shape):
# Get the channel dimension
dim = input_shape[-1]
shape = [1, 1, 1, dim]
# Initialize beta with shape (1, 1, 1, C)
self.beta = self.add_weight(
shape=shape,
name="beta",
dtype=self.dtype,
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint,
)
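
For reference, here is a minimal, illustrative usage sketch of the layer added in this PR. The input shape and values below are made up for the example; the comments restate the computation that call() performs above.

import numpy as np
from tensorflow_addons.layers import FilterResponseNormalization

# Illustrative input: a batch of 8 RGB images of size 32x32 in NHWC layout.
x = np.random.rand(8, 32, 32, 3).astype(np.float32)

frn = FilterResponseNormalization(learned_epsilon=True)
y = frn(x)  # output has the same shape as the input: (8, 32, 32, 3)

# Per sample and per channel, the layer computes
#   nu2 = mean(x ** 2) over the spatial axes [1, 2]
#   y   = gamma * x * rsqrt(nu2 + epsilon) + beta
# which matches the reduce_mean / rsqrt computation in call().
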
117 changes: 117 additions & 0 deletions tensorflow_addons/layers/normalizations_test.py
@@ -19,6 +19,7 @@
import numpy as np
import tensorflow as tf

from tensorflow_addons.layers.normalizations import FilterResponseNormalization
from tensorflow_addons.layers.normalizations import GroupNormalization
from tensorflow_addons.layers.normalizations import InstanceNormalization
from tensorflow_addons.utils import test_utils
@@ -331,5 +332,121 @@ def test_groupnorm_convnet_no_center_no_scale(self):
)


def calculate_frn(
x, beta=0.2, gamma=1, eps=1e-6, learned_epsilon=False, dtype=np.float32
):
if learned_epsilon:
eps = eps + 1e-4
eps = tf.cast(eps, dtype=dtype)
nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
x = x * tf.math.rsqrt(nu2 + tf.abs(eps))
return gamma * x + beta


def set_random_seed():
seed = 0x2020
np.random.seed(seed)
tf.random.set_seed(seed)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_with_beta(dtype):
set_random_seed()
inputs = np.random.rand(28, 28, 1).astype(dtype)
inputs = np.expand_dims(inputs, axis=0)
frn = FilterResponseNormalization(
beta_initializer="ones", gamma_initializer="ones", dtype=dtype
)
frn.build((None, 28, 28, 1))
observed = frn(inputs)
expected = calculate_frn(inputs, beta=1, gamma=1, dtype=dtype)
np.testing.assert_allclose(expected[0], observed[0])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_with_gamma(dtype):
set_random_seed()
inputs = np.random.rand(28, 28, 1).astype(dtype)
inputs = np.expand_dims(inputs, axis=0)
frn = FilterResponseNormalization(
beta_initializer="zeros", gamma_initializer="ones", dtype=dtype
)
frn.build((None, 28, 28, 1))
observed = frn(inputs)
expected = calculate_frn(inputs, beta=0, gamma=1, dtype=dtype)
np.testing.assert_allclose(expected[0], observed[0])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_with_epsilon(dtype):
set_random_seed()
inputs = np.random.rand(28, 28, 1).astype(dtype)
inputs = np.expand_dims(inputs, axis=0)
frn = FilterResponseNormalization(
beta_initializer=tf.keras.initializers.Constant(0.5),
gamma_initializer="ones",
learned_epsilon=True,
dtype=dtype,
)
frn.build((None, 28, 28, 1))
observed = frn(inputs)
expected = calculate_frn(
inputs, beta=0.5, gamma=1, learned_epsilon=True, dtype=dtype
)
np.testing.assert_allclose(expected[0], observed[0])


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_keras_model(dtype):
set_random_seed()
frn = FilterResponseNormalization(
beta_initializer="ones", gamma_initializer="ones", dtype=dtype
)
random_inputs = np.random.rand(10, 32, 32, 3).astype(dtype)
random_labels = np.random.randint(2, size=(10,)).astype(dtype)
input_layer = tf.keras.layers.Input(shape=(32, 32, 3))
x = frn(input_layer)
x = tf.keras.layers.Flatten()(x)
out = tf.keras.layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.models.Model(input_layer, out)
model.compile(loss="binary_crossentropy", optimizer="sgd")
model.fit(random_inputs, random_labels, epochs=2)


@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_serialization(dtype):
frn = FilterResponseNormalization(
beta_initializer="ones", gamma_initializer="ones", dtype=dtype
)
serialized_frn = tf.keras.layers.serialize(frn)
new_layer = tf.keras.layers.deserialize(serialized_frn)
assert frn.get_config() == new_layer.get_config()


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_eps_grads(dtype):
set_random_seed()
random_inputs = np.random.rand(10, 32, 32, 3).astype(np.float32)
random_labels = np.random.randint(2, size=(10,)).astype(np.float32)
input_layer = tf.keras.layers.Input(shape=(32, 32, 3))
frn = FilterResponseNormalization(
beta_initializer="ones", gamma_initializer="ones", learned_epsilon=True
)
initial_eps_value = frn.eps_learned.numpy()[0]
x = frn(input_layer)
x = tf.keras.layers.Flatten()(x)
out = tf.keras.layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.models.Model(input_layer, out)
model.compile(loss="binary_crossentropy", optimizer="sgd")
model.fit(random_inputs, random_labels, epochs=1)
final_eps_value = frn.eps_learned.numpy()[0]
assert initial_eps_value != final_eps_value


if __name__ == "__main__":
sys.exit(pytest.main([__file__]))