 # limitations under the License.
 # =============================================================================

+import tensorflow as tf
+
 from tensorflow import name_scope
-from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_impl
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.ops.linalg_ops import norm
+from tensorflow.python.ops.math_ops import sqrt
+from tensorflow.python.ops.nn import moments
+
 from tensorflow.python.keras import initializers
-from tensorflow.python.eager import context
-from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.layers import Wrapper
-from tensorflow.python.ops import variables as tf_variables
+from tensorflow_addons.utils.python import keras_utils


+@keras_utils.register_keras_custom_object
 class WeightNormalization(Wrapper):
     """This wrapper reparameterizes a layer by decoupling the weight's
     magnitude and direction. This speeds up convergence by improving the
@@ -52,17 +57,12 @@ class WeightNormalization(Wrapper):
       ValueError: If `Layer` does not contain a `kernel` of weights
       NotImplementedError: If `data_init` is True and running graph execution
     """
-    def __init__(self, layer, data_init=False, **kwargs):
-        if not isinstance(layer, Layer):
+    def __init__(self, layer, data_init=True, **kwargs):
+        if not isinstance(layer, base_layer.Layer):
             raise ValueError(
                 'Please initialize `WeightNormalization` layer with a '
                 '`Layer` instance. You passed: {input}'.format(input=layer))

-        if not context.executing_eagerly() and data_init:
-            raise NotImplementedError(
-                'Data dependent variable initialization is not available for '
-                'graph execution')
-
         self.initialized = True
         if data_init:
             self.initialized = False
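For context, the reparameterization this wrapper implements (Salimans & Kingma, 2016) splits each kernel into a direction `v` and a per-output-unit magnitude `g`, i.e. `w = g * v / ||v||`. A minimal standalone sketch of that computation using only public TF 2.x ops; the shapes and variable names below are illustrative, not the wrapper's internals:

```python
import tensorflow as tf

# Illustrative shapes only: a dense kernel of shape [in_dim=3, units=4],
# reparameterized as w = g * v / ||v|| with one magnitude g per output unit.
v = tf.random.normal([3, 4])                 # direction (kernel-shaped)
g = tf.norm(tf.reshape(v, [-1, 4]), axis=0)  # init g to the per-unit norm of v

w = tf.nn.l2_normalize(v, axis=[0]) * g      # normalize over all but last axis

# At init w reproduces v exactly, since g was set to ||v||:
tf.debugging.assert_near(w, v)
```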
@@ -75,26 +75,24 @@ def _compute_weights(self):
         with its norm """
         with name_scope('compute_weights'):
             self.layer.kernel = nn_impl.l2_normalize(
-                self.layer.v, axis=self.norm_axes) * self.layer.g
+                self.layer.v, axis=self.kernel_norm_axes) * self.layer.g

     def _init_norm(self, weights):
         """Set the norm of the weight vector"""
-        from tensorflow.python.ops.linalg_ops import norm
         with name_scope('init_norm'):
             flat = array_ops.reshape(weights, [-1, self.layer_depth])
             return array_ops.reshape(norm(flat, axis=0), (self.layer_depth,))

     def _data_dep_init(self, inputs):
-        """Data dependent initialization for eager execution"""
-        from tensorflow.python.ops.nn import moments
-        from tensorflow.python.ops.math_ops import sqrt
+        """Data dependent initialization"""

         with name_scope('data_dep_init'):
             # Generate data dependent init values
             activation = self.layer.activation
             self.layer.activation = None
             x_init = self.layer.call(inputs)
-            m_init, v_init = moments(x_init, self.norm_axes)
+            data_norm_axes = list(range(x_init.shape.rank - 1))
+            m_init, v_init = moments(x_init, data_norm_axes)
             scale_init = 1. / sqrt(v_init + 1e-10)

         # Assign data dependent init values
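For context: data-dependent initialization, per the same paper, pushes one batch through the layer with a unit-norm kernel and then rescales `g` and shifts the bias so the initial pre-activations are roughly standardized. A hedged sketch of that update under the paper's formulation; the file's exact assignments sit in the lines elided between this hunk and the next:

```python
import tensorflow as tf

# Hedged sketch of the paper's update, not this file's exact code: given the
# moments of one batch of pre-activations, rescale g and shift the bias so
# initial outputs are roughly zero-mean, unit-variance.
x_init = tf.random.normal([64, 10]) * 3. + 2.     # stand-in pre-activations
m_init, v_init = tf.nn.moments(x_init, axes=[0])  # per-unit mean / variance
scale_init = 1. / tf.sqrt(v_init + 1e-10)

g_init = tf.ones([10]) * scale_init               # g    <- g * scale_init
b_init = -m_init * scale_init                     # bias <- -mean * scale_init

# Applying the corrected scale and shift standardizes the init batch:
x_std = x_init * scale_init + b_init              # ~N(0, 1) per unit
```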
@@ -106,7 +104,7 @@ def _data_dep_init(self, inputs):
     def build(self, input_shape):
         """Build `Layer`"""
         input_shape = tensor_shape.TensorShape(input_shape).as_list()
-        self.input_spec = InputSpec(shape=input_shape)
+        self.input_spec = base_layer.InputSpec(shape=input_shape)

         if not self.layer.built:
             self.layer.build(input_shape)
@@ -120,7 +118,7 @@ def build(self, input_shape):

         # The kernel's filter or unit dimension is -1
         self.layer_depth = int(self.layer.kernel.shape[-1])
-        self.norm_axes = list(range(self.layer.kernel.shape.ndims - 1))
+        self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))

         self.layer.v = self.layer.kernel
         self.layer.g = self.layer.add_variable(
@@ -131,22 +129,22 @@ def build(self, input_shape):
             trainable=True,
             aggregation=tf_variables.VariableAggregation.MEAN)

-        with ops.control_dependencies([self.layer.g.assign(
-            self._init_norm(self.layer.v))]):
-            self._compute_weights()
+        # TODO: Check if this needs control deps in TF2 graph mode
+        self.layer.g.assign(self._init_norm(self.layer.v))
+        self._compute_weights()

         self.layer.built = True

         super(WeightNormalization, self).build()
         self.built = True

+    @tf.function
     def call(self, inputs):
         """Call `Layer`"""
-        if context.executing_eagerly():
-            if not self.initialized:
-                self._data_dep_init(inputs)
-            self._compute_weights()  # Recompute weights for each forward pass
+        if not self.initialized:
+            self._data_dep_init(inputs)

+        self._compute_weights()  # Recompute weights for each forward pass
         output = self.layer.call(inputs)
         return output
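With `data_init=True` now the default and the class registered as a custom Keras object, end-to-end usage should look roughly like the following, assuming the wrapper is ultimately exported as `tfa.layers.WeightNormalization`:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Wrap individual layers; with data_init=True (the new default), g and the
# bias are calibrated from the first batch that reaches call().
model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(
        tf.keras.layers.Dense(64, activation='relu'), input_shape=(784,)),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(10)),
])
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
```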