Skip to content

group normalization layer test #766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
13 commits merged on Mar 4, 2020
81 changes: 81 additions & 0 deletions tensorflow_addons/layers/normalizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _test_specific_layer(self, inputs, axis, groups, center, scale):

def _create_and_fit_Sequential_model(self, layer, shape):
# Helperfunction for quick evaluation
np.random.seed(0x2020)
model = tf.keras.models.Sequential()
model.add(layer)
model.add(tf.keras.layers.Dense(32))
Expand Down Expand Up @@ -233,6 +234,7 @@ def test_regularizations(self):
def test_groupnorm_conv(self):
# Check if Axis is working for CONV nets
# Testing for 1 == LayerNorm, 5 == GroupNorm, -1 == InstanceNorm
np.random.seed(0x2020)
groups = [-1, 5, 1]
for i in groups:
model = tf.keras.models.Sequential()
Expand All @@ -246,6 +248,85 @@ def test_groupnorm_conv(self):
model.fit(x=x, y=y, epochs=1)
self.assertTrue(hasattr(model.layers[0], "gamma"))

def test_groupnorm_correctness_1d(self):
    # GroupNormalization on flat 1D features should whiten the input:
    # after undoing the learned beta/gamma, the output is ~N(0, 1).
    np.random.seed(0x2020)
    gn = GroupNormalization(input_shape=(10,), groups=2)
    net = tf.keras.models.Sequential()
    net.add(gn)
    net.compile(loss="mse", optimizer="rmsprop")

    # Input deliberately far from standard normal: mean 5.0, std 10.0.
    data = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
    net.fit(data, data, epochs=5, verbose=0)
    normalized = net.predict(data)
    normalized -= self.evaluate(gn.beta)
    normalized /= self.evaluate(gn.gamma)

    self.assertAllClose(normalized.mean(), 0.0, atol=1e-1)
    self.assertAllClose(normalized.std(), 1.0, atol=1e-1)

def test_groupnorm_2d_different_groups(self):
    # Axis-1 normalization on (10, 3) inputs must whiten the data for
    # every valid grouping: 2 groups, layer-norm-like (1), and
    # instance-norm-like (10, one group per axis-1 element).
    np.random.seed(0x2020)
    for n_groups in (2, 1, 10):
        seq = tf.keras.models.Sequential()
        gn = GroupNormalization(axis=1, groups=n_groups, input_shape=(10, 3))
        seq.add(gn)
        # Input is drawn with mean 5.0 and standard deviation 10.0.
        seq.compile(loss="mse", optimizer="rmsprop")
        data = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 3))
        seq.fit(data, data, epochs=5, verbose=0)
        out = seq.predict(data)
        # Undo the learned affine transform before checking the stats.
        out -= np.reshape(self.evaluate(gn.beta), (1, 10, 1))
        out /= np.reshape(self.evaluate(gn.gamma), (1, 10, 1))

        self.assertAllClose(
            out.mean(axis=(0, 1), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
        )
        self.assertAllClose(
            out.std(axis=(0, 1), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
        )

def test_groupnorm_convnet(self):
    # Channels-first conv-style input (3, 4, 4), one group per channel:
    # after removing beta/gamma, per-channel stats should be ~N(0, 1).
    np.random.seed(0x2020)
    gn = GroupNormalization(axis=1, input_shape=(3, 4, 4), groups=3)
    net = tf.keras.models.Sequential()
    net.add(gn)
    net.compile(loss="mse", optimizer="sgd")

    # Input drawn with mean 5.0 and standard deviation 10.0.
    data = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
    net.fit(data, data, epochs=4, verbose=0)
    out = net.predict(data)
    out -= np.reshape(self.evaluate(gn.beta), (1, 3, 1, 1))
    out /= np.reshape(self.evaluate(gn.gamma), (1, 3, 1, 1))

    # Reduce over batch and spatial dims, leaving one stat per channel.
    per_channel = (0, 2, 3)
    self.assertAllClose(
        np.mean(out, axis=per_channel, dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
    )
    self.assertAllClose(
        np.std(out, axis=per_channel, dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
    )

def test_groupnorm_convnet_no_center_no_scale(self):
    # With center=False and scale=False the layer learns no beta/gamma,
    # so the raw prediction itself must already be ~N(0, 1) per channel.
    np.random.seed(0x2020)
    gn = GroupNormalization(
        axis=-1, groups=2, center=False, scale=False, input_shape=(3, 4, 4)
    )
    net = tf.keras.models.Sequential()
    net.add(gn)
    net.compile(loss="mse", optimizer="sgd")
    # Input drawn with mean 5.0 and standard deviation 10.0.
    data = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
    net.fit(data, data, epochs=4, verbose=0)
    out = net.predict(data)

    self.assertAllClose(
        np.mean(out, axis=(0, 2, 3), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
    )
    self.assertAllClose(
        np.std(out, axis=(0, 2, 3), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
    )


if __name__ == "__main__":
    # Run this test module via TensorFlow's test runner when executed directly.
    tf.test.main()