@@ -49,20 +49,13 @@ def __init__(
         super().__init__(name, **kwargs)
         self._optimizer = tf.keras.optimizers.get(inner_optimizer)
         self._step = None
-        self._gradients = {}
         self._accum_steps = accum_steps
         self._reduction = reduction

         def _accum_grad(grads_and_vars):
-            with tf.init_scope():
-                if not self._gradients:
-                    for grad, var in grads_and_vars:
-                        self._gradients[var.ref()] = tf.Variable(
-                            tf.zeros_like(var), trainable=False
-                        )
             new_grads_and_vars = []
             for grad, var in grads_and_vars:
-                handle = self._gradients[var.ref()]
+                handle = self.get_slot(var, "ga")

                 if isinstance(grad, tf.IndexedSlices):
                     handle.scatter_add(grad)
@@ -84,9 +77,11 @@ def _get_grad():
                         values = tf.gather(new_grad, indices)
                         dense_shape = tf.constant(new_grad.shape.as_list())
                         handle.assign(
-                            tf.zeros_like(handle), use_locking=self._use_locking
+                            tf.zeros_like(handle),
+                            use_locking=self._use_locking,
+                            read_value=False,
                         )
-                        return values, tf.cast(indices, tf.int32), dense_shape
+                        return values, tf.cast(indices, grad.indices.dtype), dense_shape

                     values, indices, dense_shape = tf.cond(
                         self.step % self._accum_steps == 0,
@@ -100,14 +95,18 @@ def _get_grad():
                     new_grad = tf.IndexedSlices(values, indices, dense_shape)
                     new_grads_and_vars.append((new_grad, var))
                 else:
-                    handle.assign_add(grad)
+                    handle.assign_add(
+                        grad, use_locking=self._use_locking, read_value=False
+                    )

                     def _get_grad():
                         new_grad = handle.read_value()
                         if self._reduction == "MEAN":
                             new_grad /= tf.cast(self._accum_steps, new_grad.dtype)
                         handle.assign(
-                            tf.zeros_like(handle), use_locking=self._use_locking
+                            tf.zeros_like(handle),
+                            use_locking=self._use_locking,
+                            read_value=False,
                         )
                         return new_grad

@@ -119,11 +118,39 @@ def _get_grad():
                     new_grads_and_vars.append((new_grad, var))
             return new_grads_and_vars

-        self._optimizer.gradient_transformers.append(_accum_grad)
+        self.gradient_transformers.append(_accum_grad)
         self._iterations = self._optimizer.iterations

     def _create_slots(self, var_list):
         self._optimizer._create_slots(var_list=var_list)
+        for var in var_list:
+            self.add_slot(var, "ga")
+
+    def _resource_apply_dense(self, grad, handle, apply_state):
+        if "apply_state" in self._optimizer._dense_apply_args:
+            return self.inner_optimizer._resource_apply_dense(grad, handle, apply_state)
+        else:
+            return self.inner_optimizer._resource_apply_dense(grad, handle)
+
+    def _resource_apply_sparse(self, grad, handle, indices, apply_state):
+        if "apply_state" in self._optimizer._sparse_apply_args:
+            return self.inner_optimizer._resource_apply_sparse(
+                grad, handle, indices, apply_state=apply_state
+            )
+        else:
+            return self.inner_optimizer._resource_apply_sparse(grad, handle, indices)
+
+    def _resource_apply_sparse_duplicate_indices(
+        self, grad, handle, indices, apply_state=None
+    ):
+        if "apply_state" in self._optimizer._sparse_apply_args:
+            return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+                grad, handle, indices, apply_state=apply_state
+            )
+        else:
+            return self.inner_optimizer._resource_apply_sparse_duplicate_indices(
+                grad, handle, indices
+            )

     @property
     def step(self):
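For orientation: the hunk above stops reaching into `self._optimizer.gradient_transformers` and instead appends `_accum_grad` to the wrapper's own `gradient_transformers`, the `OptimizerV2` hook that runs on the aggregated `(gradient, variable)` pairs just before the update, while the new `_create_slots` and `_resource_apply_*` overrides keep the `"ga"` accumulator slots on the wrapper and delegate the actual parameter update to the inner optimizer. Below is a minimal, self-contained sketch of that hook; it is not this repository's code: `scale_by_half` is a made-up transformer, and it assumes a TF version where the hook exists (`tf.keras.optimizers.SGD` on 2.4–2.10, the `legacy` namespace shown here on 2.11–2.15).

```python
# Standalone sketch of the OptimizerV2 `gradient_transformers` hook used above.
# Assumption: TF 2.11-2.15 (on 2.4-2.10, use tf.keras.optimizers.SGD directly).
import tensorflow as tf


def scale_by_half(grads_and_vars):
    # A transformer receives [(grad, var), ...] after aggregation and must
    # return the same structure; the diff's _accum_grad follows this contract.
    return [(grad * 0.5, var) for grad, var in grads_and_vars]


opt = tf.keras.optimizers.legacy.SGD(
    learning_rate=1.0, gradient_transformers=[scale_by_half]
)
var = tf.Variable(1.0)
opt.apply_gradients([(tf.constant(1.0), var)])
print(var.numpy())  # 0.5: the gradient was transformed before the update
```

Because transformers see the already-aggregated gradients, `_accum_grad` can buffer or release the full accumulated gradient in one place instead of patching each apply path separately.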
@@ -133,7 +160,6 @@ def step(self):
                 self._step = self.add_weight(
                     "iter",
                     shape=[],
-                    initializer="ones",
                     dtype=tf.int64,
                     trainable=False,
                     aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
@@ -151,49 +177,15 @@ def step(self, variable):
         self._step = variable
         self._weights.append(self._step)

-    @property
-    def gradients(self):
-        """The accumulated gradients on the current replica."""
-        if not self._gradients:
-            raise ValueError(
-                "The accumulator should be called first to initialize the gradients"
-            )
-        return list(
-            gradient.read_value() if gradient is not None else gradient
-            for _, gradient in self._gradients
-        )
-
     def apply_gradients(self, grads_and_vars, name=None, **kwargs):
-        train_op = self._optimizer.apply_gradients(grads_and_vars, name, **kwargs)
-        with tf.control_dependencies([train_op]):
-            with tf.control_dependencies(
-                [
-                    self._optimizer.iterations.assign_add(
-                        tf.cast(self.step % self._accum_steps == 0, tf.int64),
-                        read_value=False,
-                    )
-                ]
-            ):
-                return self.step.assign_add(1, read_value=False)
-
-    def reset(self):
-        """Resets the accumulated gradients on the current replica."""
-        assign_ops = []
-        if not self._gradients:
-            return assign_ops
-
-        for _, gradient in self._gradients:
-            if gradient is not None:
-                assign_ops.append(
-                    gradient.assign(
-                        tf.zeros_like(gradient),
-                        use_locking=self._use_locking,
-                        read_value=False,
-                    )
+        with tf.control_dependencies([self.step.assign_add(1, read_value=False)]):
+            train_op = super().apply_gradients(grads_and_vars, name, **kwargs)
+            with tf.control_dependencies([train_op]):
+                return self.iterations.assign_sub(
+                    tf.cast(self.step % self._accum_steps != 0, tf.int64),
+                    read_value=False,
                 )

-        return tf.group(assign_ops)
-
     @property
     def inner_optimizer(self):
         """The optimizer that this LossScaleOptimizer is wrapping."""