Skip to content

Commit 4aedd36

Browse files
seanpmorgan authored and facaiy committed
TF2 WeightNormalization (#29)
* FIX: Modify WeightNormalization for TF2
* FIX: Update bazel tests
* FIX: Modify WeightNormalization for TF2
* FIX: Modify WeightNormalization for TF2
1 parent bfa919e commit 4aedd36

File tree

4 files changed

+69
-42
lines changed

4 files changed

+69
-42
lines changed

tensorflow_addons/layers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ py_library(
1919

2020
py_test(
2121
name = "layers_wrappers_py_test",
22+
size = "small",
2223
srcs = [
2324
"python/wrappers_test.py",
2425
],

tensorflow_addons/layers/python/wrappers.py

Lines changed: 25 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -13,19 +13,24 @@
1313
# limitations under the License.
1414
# =============================================================================
1515

16+
import tensorflow as tf
17+
1618
from tensorflow import name_scope
17-
from tensorflow.python.framework import ops
1819
from tensorflow.python.framework import tensor_shape
1920
from tensorflow.python.ops import array_ops
2021
from tensorflow.python.ops import nn_impl
22+
from tensorflow.python.ops import variables as tf_variables
23+
from tensorflow.python.ops.linalg_ops import norm
24+
from tensorflow.python.ops.math_ops import sqrt
25+
from tensorflow.python.ops.nn import moments
26+
2127
from tensorflow.python.keras import initializers
22-
from tensorflow.python.eager import context
23-
from tensorflow.python.keras.engine.base_layer import Layer
24-
from tensorflow.python.keras.engine.base_layer import InputSpec
28+
from tensorflow.python.keras.engine import base_layer
2529
from tensorflow.python.keras.layers import Wrapper
26-
from tensorflow.python.ops import variables as tf_variables
30+
from tensorflow_addons.utils.python import keras_utils
2731

2832

33+
@keras_utils.register_keras_custom_object
2934
class WeightNormalization(Wrapper):
3035
""" This wrapper reparameterizes a layer by decoupling the weight's
3136
magnitude and direction. This speeds up convergence by improving the
@@ -52,17 +57,12 @@ class WeightNormalization(Wrapper):
5257
ValueError: If `Layer` does not contain a `kernel` of weights
5358
NotImplementedError: If `data_init` is True and running graph execution
5459
"""
55-
def __init__(self, layer, data_init=False, **kwargs):
56-
if not isinstance(layer, Layer):
60+
def __init__(self, layer, data_init=True, **kwargs):
61+
if not isinstance(layer, base_layer.Layer):
5762
raise ValueError(
5863
'Please initialize `WeightNormalization` layer with a '
5964
'`Layer` instance. You passed: {input}'.format(input=layer))
6065

61-
if not context.executing_eagerly() and data_init:
62-
raise NotImplementedError(
63-
'Data dependent variable initialization is not available for '
64-
'graph execution')
65-
6666
self.initialized = True
6767
if data_init:
6868
self.initialized = False
@@ -75,26 +75,24 @@ def _compute_weights(self):
7575
with its norm """
7676
with name_scope('compute_weights'):
7777
self.layer.kernel = nn_impl.l2_normalize(
78-
self.layer.v, axis=self.norm_axes) * self.layer.g
78+
self.layer.v, axis=self.kernel_norm_axes) * self.layer.g
7979

8080
def _init_norm(self, weights):
8181
"""Set the norm of the weight vector"""
82-
from tensorflow.python.ops.linalg_ops import norm
8382
with name_scope('init_norm'):
8483
flat = array_ops.reshape(weights, [-1, self.layer_depth])
8584
return array_ops.reshape(norm(flat, axis=0), (self.layer_depth,))
8685

8786
def _data_dep_init(self, inputs):
88-
"""Data dependent initialization for eager execution"""
89-
from tensorflow.python.ops.nn import moments
90-
from tensorflow.python.ops.math_ops import sqrt
87+
"""Data dependent initialization"""
9188

9289
with name_scope('data_dep_init'):
9390
# Generate data dependent init values
9491
activation = self.layer.activation
9592
self.layer.activation = None
9693
x_init = self.layer.call(inputs)
97-
m_init, v_init = moments(x_init, self.norm_axes)
94+
data_norm_axes = list(range(x_init.shape.rank - 1))
95+
m_init, v_init = moments(x_init, data_norm_axes)
9896
scale_init = 1. / sqrt(v_init + 1e-10)
9997

10098
# Assign data dependent init values
@@ -106,7 +104,7 @@ def _data_dep_init(self, inputs):
106104
def build(self, input_shape):
107105
"""Build `Layer`"""
108106
input_shape = tensor_shape.TensorShape(input_shape).as_list()
109-
self.input_spec = InputSpec(shape=input_shape)
107+
self.input_spec = base_layer.InputSpec(shape=input_shape)
110108

111109
if not self.layer.built:
112110
self.layer.build(input_shape)
@@ -120,7 +118,7 @@ def build(self, input_shape):
120118

121119
# The kernel's filter or unit dimension is -1
122120
self.layer_depth = int(self.layer.kernel.shape[-1])
123-
self.norm_axes = list(range(self.layer.kernel.shape.ndims - 1))
121+
self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))
124122

125123
self.layer.v = self.layer.kernel
126124
self.layer.g = self.layer.add_variable(
@@ -131,22 +129,22 @@ def build(self, input_shape):
131129
trainable=True,
132130
aggregation=tf_variables.VariableAggregation.MEAN)
133131

134-
with ops.control_dependencies([self.layer.g.assign(
135-
self._init_norm(self.layer.v))]):
136-
self._compute_weights()
132+
# TODO: Check if this needs control deps in TF2 graph mode
133+
self.layer.g.assign(self._init_norm(self.layer.v))
134+
self._compute_weights()
137135

138136
self.layer.built = True
139137

140138
super(WeightNormalization, self).build()
141139
self.built = True
142140

141+
@tf.function
143142
def call(self, inputs):
144143
"""Call `Layer`"""
145-
if context.executing_eagerly():
146-
if not self.initialized:
147-
self._data_dep_init(inputs)
148-
self._compute_weights() # Recompute weights for each forward pass
144+
if not self.initialized:
145+
self._data_dep_init(inputs)
149146

147+
self._compute_weights() # Recompute weights for each forward pass
150148
output = self.layer.call(inputs)
151149
return output
152150

tensorflow_addons/layers/python/wrappers_test.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
import numpy as np
2121
from tensorflow_addons.layers.python import wrappers
2222

23+
from tensorflow.python import keras
2324
from tensorflow.python.ops import random_ops
24-
from tensorflow.python.platform import test
25-
from tensorflow.python.layers import layers
26-
from tensorflow.python.training.rmsprop import RMSPropOptimizer
25+
from tensorflow.python.keras.optimizer_v2.rmsprop import RMSprop
2726

27+
from tensorflow.python.platform import test
2828
from tensorflow.python.framework import test_util as tf_test_util
29-
from tensorflow.python import keras
29+
from tensorflow.python.keras import testing_utils
3030

3131

3232
class WeightNormalizationTest(test.TestCase):
@@ -37,11 +37,25 @@ def test_weightnorm_dense_train(self):
3737
model.add(wrappers.WeightNormalization(
3838
keras.layers.Dense(2), input_shape=(3, 4)))
3939

40-
model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
40+
model.compile(optimizer=RMSprop(learning_rate=0.001), loss='mse')
41+
model.fit(
42+
np.random.random((10, 3, 4)),
43+
np.random.random((10, 3, 2)),
44+
epochs=3,
45+
batch_size=10)
46+
self.assertTrue(hasattr(model.layers[0].layer, 'g'))
47+
48+
@tf_test_util.run_all_in_graph_and_eager_modes
49+
def test_weightnorm_dense_train_notinit(self):
50+
model = keras.models.Sequential()
51+
model.add(wrappers.WeightNormalization(
52+
keras.layers.Dense(2), input_shape=(3, 4), data_init=False))
53+
54+
model.compile(optimizer=RMSprop(learning_rate=0.001), loss='mse')
4155
model.fit(
4256
np.random.random((10, 3, 4)),
4357
np.random.random((10, 3, 2)),
44-
epochs=1,
58+
epochs=3,
4559
batch_size=10)
4660
self.assertTrue(hasattr(model.layers[0].layer, 'g'))
4761

@@ -53,31 +67,44 @@ def test_weightnorm_conv2d(self):
5367
input_shape=(4, 4, 3)))
5468

5569
model.add(keras.layers.Activation('relu'))
56-
model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
57-
model.train_on_batch(
70+
model.compile(optimizer=RMSprop(learning_rate=0.001), loss='mse')
71+
model.fit(
5872
np.random.random((2, 4, 4, 3)),
59-
np.random.random((2, 4, 4, 5)))
73+
np.random.random((2, 4, 4, 5)),
74+
epochs=3,
75+
batch_size=10)
6076

6177
self.assertTrue(hasattr(model.layers[0].layer, 'g'))
6278

6379
@tf_test_util.run_all_in_graph_and_eager_modes
64-
def test_weight_norm_tflayers(self):
80+
def test_weightnorm_tflayers(self):
6581
images = random_ops.random_uniform((2, 4, 4, 3))
66-
wn_wrapper = wrappers.WeightNormalization(layers.Conv2D(32, [2, 2]),
67-
input_shape=(4, 4, 3))
82+
wn_wrapper = wrappers.WeightNormalization(
83+
keras.layers.Conv2D(32, [2, 2]), input_shape=(4, 4, 3))
6884
wn_wrapper.apply(images)
6985
self.assertTrue(hasattr(wn_wrapper.layer, 'g'))
7086

7187
@tf_test_util.run_all_in_graph_and_eager_modes
72-
def test_weight_norm_nonlayer(self):
88+
def test_weightnorm_nonlayer(self):
7389
images = random_ops.random_uniform((2, 4, 43))
7490
with self.assertRaises(ValueError):
7591
wrappers.WeightNormalization(images)
7692

7793
@tf_test_util.run_all_in_graph_and_eager_modes
78-
def test_weight_norm_nokernel(self):
94+
def test_weightnorm_nokernel(self):
7995
with self.assertRaises(ValueError):
80-
wrappers.WeightNormalization(layers.MaxPooling2D(2, 2)).build((2, 2))
96+
wrappers.WeightNormalization(
97+
keras.layers.MaxPooling2D(2, 2)).build((2, 2))
98+
99+
def test_weightnorm_keras(self):
100+
input_data = np.random.random((10, 3, 4)).astype(np.float32)
101+
outputs = testing_utils.layer_test(
102+
wrappers.WeightNormalization,
103+
kwargs={
104+
'layer': keras.layers.Dense(2),
105+
'input_shape': (3, 4)
106+
},
107+
input_data=input_data)
81108

82109

83110
if __name__ == "__main__":

tensorflow_addons/text/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ py_library(
4545

4646
py_test(
4747
name = "text_ops_py_test",
48+
size = "small",
4849
srcs = [
4950
"python/skip_gram_ops_test.py"
5051
],

0 commit comments

Comments
 (0)