@@ -204,21 +204,19 @@ def iterations(self):
     def _track_variable(self, variable):
         self._tracker.add_to_store("variables", variable)
 
+    def _overwrite_variable_with_gradient(self, variable):
+        return getattr(variable, "overwrite_with_gradient", False)
+
     @tracking.no_automatic_dependency_tracking
     def build(self, variables):
         if self.use_ema:
-            self._model_variables_moving_average = []
+            self._model_variables_moving_average = self.add_optimizer_variables(
+                variables, "average"
+            )
         if self.gradient_accumulation_steps:
             self._accumulated_gradients = []
         for i, variable in enumerate(variables):
             self._trainable_variables_indices[self._var_key(variable)] = i
-            if self.use_ema:
-                self._model_variables_moving_average.append(
-                    self.add_variable_from_reference(
-                        variable,
-                        name="average",
-                    )
-                )
             if self.gradient_accumulation_steps:
                 self._accumulated_gradients.append(
                     self.add_variable_from_reference(
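With slot creation factored out, a subclass `build` collapses to a single helper call. A minimal sketch of the resulting pattern, assuming a Keras-3-style optimizer subclass (`MomentumSGD` and its `momentums` attribute are hypothetical, not part of this diff):

```python
from keras.optimizers import Optimizer


class MomentumSGD(Optimizer):
    """Hypothetical optimizer illustrating the new build pattern."""

    def build(self, variables):
        if self.built:
            return
        super().build(variables)
        # One "momentum" slot per model variable, created in a single call;
        # entries are `None` for variables flagged `overwrite_with_gradient`.
        self.momentums = self.add_optimizer_variables(variables, "momentum")
```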
@@ -323,6 +321,49 @@ def add_variable_from_reference(
             name=name,
         )
 
+    def add_optimizer_variables(
+        self, trainable_variables, name, initializer="zeros"
+    ):
+        """Add optimizer variables from the list of trainable model variables.
+
+        Creates an optimizer variable for each of the supplied model
+        variables. For example, in SGD with momentum, a corresponding
+        momentum variable of the same shape and dtype is created for
+        each model variable.
+
+        For trainable variables with `v.overwrite_with_gradient == True`,
+        `None` is inserted into the output list instead, since the optimizer
+        variable would never be used and allocating it would be wasteful.
+
+        Args:
+            trainable_variables: A list of `keras.Variable`s; the model
+                variables for which to create optimizer variables.
+            name: The name prefix of the optimizer variables to be
+                created. Each variable name will follow the pattern
+                `{variable_name}_{trainable_variable.name}`, e.g.,
+                `momentum/dense_1`.
+            initializer: Initializer object to use to populate the initial
+                variable value, or string name of a built-in initializer (e.g.
+                `"random_normal"`). If unspecified, defaults to `"zeros"`.
+
+        Returns:
+            A list of optimizer variables, in the format of `keras.Variable`s.
+        """
+        optimizer_variables = []
+        for variable in trainable_variables:
+            if not self._overwrite_variable_with_gradient(variable):
+                optimizer_variables.append(
+                    self.add_variable_from_reference(
+                        variable,
+                        name=name,
+                        initializer=initializer,
+                    )
+                )
+            else:
+                optimizer_variables.append(None)
+
+        return optimizer_variables
+
     def _check_variables_are_known(self, variables):
         for v in variables:
             if self._var_key(v) not in self._trainable_variables_indices:
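A hedged sketch of how such a slot list is consumed afterwards, continuing the hypothetical `MomentumSGD` above; `_get_variable_index` is assumed to be the base class lookup built on `_trainable_variables_indices`:

```python
    def update_step(self, gradient, variable, learning_rate):
        # Variables flagged `overwrite_with_gradient` never reach this
        # method (they are filtered out beforehand, see below), so the
        # `None` placeholders in `self.momentums` are never dereferenced.
        m = self.momentums[self._get_variable_index(variable)]
        m.assign(m * 0.9 - learning_rate * gradient)  # 0.9: illustrative decay
        variable.assign_add(m)
```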
@@ -544,7 +585,8 @@ def _backend_update_step(self, grads, trainable_variables, learning_rate):
 
     def _backend_reset_gradient_accumulators(self):
         for g_acc in self._accumulated_gradients:
-            g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))
+            if g_acc is not None:
+                g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))
 
     def _backend_increment_gradient_accumulators(self, grads, acc_grads):
         new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)]
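For orientation, a simplified, hypothetical sketch of the accumulation cycle these two helpers participate in (the real control flow is backend-specific and not part of this diff; `None` placeholders are glossed over here):

```python
# Accumulate gradients for `steps` iterations, apply the average once,
# then zero the accumulators via the guarded reset above.
steps = self.gradient_accumulation_steps
if (self.iterations + 1) % steps == 0:
    avg_grads = [
        (g + acc) / steps
        for g, acc in zip(grads, self._accumulated_gradients)
    ]
    self._backend_update_step(avg_grads, trainable_variables, learning_rate)
    self._backend_reset_gradient_accumulators()
else:
    self._backend_increment_gradient_accumulators(
        grads, self._accumulated_gradients
    )
```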
@@ -711,8 +753,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         After the update, the processed pairs will be filtered out.
         """
         # Shortcut for `tf.Variable` because it doesn't have a
-        # `overwrite_with_gradient` attr
-        if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
+        # `overwrite_with_gradient` attr.
+        if not any(self._overwrite_variable_with_gradient(v) for v in vars):
             return grads, vars
 
         # Shallow copies
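The rewritten guard also changes when the shortcut fires: the old check bailed out as soon as *any* variable merely lacked the attribute (true for every plain `tf.Variable`), while the new one bails out exactly when *no* variable opts in; the `getattr(..., False)` default in the helper is what keeps attribute-less variables safe. A small demonstration of that default:

```python
import tensorflow as tf

v = tf.Variable(1.0)
# Plain `tf.Variable` has no `overwrite_with_gradient` attribute, so the
# helper's getattr default treats it as "do not overwrite".
print(getattr(v, "overwrite_with_gradient", False))  # -> False
```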
@@ -722,7 +764,7 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
         # Iterate from right to left for safe popping
         for i in range(len(filtered_grads) - 1, -1, -1):
             g, v = filtered_grads[i], filtered_vars[i]
-            if v.overwrite_with_gradient:
+            if self._overwrite_variable_with_gradient(v):
                 if self.gradient_accumulation_steps:
                     # Utilize a stateless manner for JAX compatibility
                     steps = self.gradient_accumulation_steps
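For context on the flag this loop checks: a variable with `overwrite_with_gradient` set receives a "gradient" that is already its desired new value, so the optimizer assigns it directly rather than applying an update rule (used, e.g., for loss-scale statistics in mixed precision). A minimal sketch of opting a variable in, assuming the Keras 3 `Variable` API:

```python
import keras

# A scalar whose "gradient" is a precomputed replacement value rather
# than a derivative, e.g. a loss-scale statistic.
scale = keras.Variable(0.0, name="loss_scale")
scale.overwrite_with_gradient = True  # opt in to direct overwrite
```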
@@ -886,11 +928,12 @@ def _update_model_variables_moving_average(self, trainable_variables):
             for var, average in zip(
                 trainable_variables, self._model_variables_moving_average
             ):
-                not_first_step = ops.not_equal(self.iterations, 0)
-                momentum = (
-                    ops.cast(not_first_step, var.dtype) * self.ema_momentum
-                )
-                average.assign(momentum * average + (1 - momentum) * var)
+                if average is not None:
+                    not_first_step = ops.not_equal(self.iterations, 0)
+                    momentum = (
+                        ops.cast(not_first_step, var.dtype) * self.ema_momentum
+                    )
+                    average.assign(momentum * average + (1 - momentum) * var)
 
     def _overwrite_model_variables_with_average_value(
         self, trainable_variables
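The guarded EMA rule above seeds the average with the raw variable value on the first iteration (the cast forces momentum to 0) and decays with `ema_momentum` thereafter. A standalone numeric sketch of the same recurrence, with illustrative values:

```python
import numpy as np

ema_momentum = 0.99
average = np.float32(0.0)
for iteration, var in enumerate([2.0, 4.0, 4.0]):
    # Momentum is zero on the first step, so the average is seeded with
    # the raw value; afterwards it is an exponential moving average.
    momentum = 0.0 if iteration == 0 else ema_momentum
    average = momentum * average + (1 - momentum) * var
    print(iteration, float(average))  # ~2.0, ~2.02, ~2.0398
```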
@@ -909,7 +952,8 @@ def _overwrite_model_variables_with_average_value(
             for var, average_var in zip(
                 trainable_variables, self._model_variables_moving_average
             ):
-            var.assign(average_var)
+            if average_var is not None:
+                var.assign(average_var)
 
     def finalize_variable_values(self, var_list):
         """Set the final value of model's trainable variables.