@@ -129,6 +129,7 @@ def __init__(self,
         self._params_name = set()
         self._apply_decay_param_fun = apply_decay_param_fun
         self._coeff = coeff
+        self._lr_to_coeff = dict()
         super(AdamW, self).__init__(
             learning_rate=learning_rate,
             parameters=parameters,
@@ -139,97 +140,48 @@ def __init__(self,
             name=name,
             lazy_mode=lazy_mode)
 
-    def _scale_parameters(self, params_and_grads):
+    def _append_decoupled_weight_decay(self, block, param_and_grad):
         """
-        Adds weight decay ops.
-            scaled_parameter = parameter * coeff
+        Add decoupled weight decay op.
+            parameter = parameter - parameter * coeff * lr
 
         Args:
-            params_and_grads: A list of (parameters, gradients) pairs,
+            block: block in which variable is to be created
+            param_and_grad: (parameters, gradients) pairs,
                 the parameters need to decay.
         Raises:
            Exception: The type of coeff and parameter is not consistent.
         """
-
-        scaled_params = []
-        for param, grad in params_and_grads:
-            # If no gradient then we don't need to do anything
-            if grad is None:
-                continue
-            if self._apply_decay_param_fun is not None \
-                    and not self._apply_decay_param_fun(param.name):
-                continue
-
-            if isinstance(self._coeff, float):
-                assert param.dtype is not paddle.fluid.core.VarDesc.VarType.FP32, \
-                    "the type of coeff(float) and parameter(%s) is not consistent." % (self._coeff.dtype)
-            else:
-                assert self._coeff.dtype == param.dtype, \
-                    "the type of coeff(%s) and parameter(%s) is not consistent." % (self._coeff.dtype, param.dtype)
-            if isinstance(self._learning_rate, float):
-                learning_rate = self._learning_rate
-            else:
-                learning_rate = self._learning_rate()
-            with param.block.program._optimized_guard(
-                [param, grad]), framework.name_scope('weight decay'):
-                scaled_params.append(
-                    (param, grad, param * self._coeff * learning_rate))
-                if param.name not in self._params_name:
-                    self._params_name.add(param.name)
-                    param = param * self._coeff
-        return scaled_params
-
-    @imperative_base.no_grad
-    def minimize(self,
-                 loss,
-                 startup_program=None,
-                 parameters=None,
-                 no_grad_set=None):
-        parameters = parameters if parameters \
-            else self._parameter_list
-
-        params_grads = self.backward(
-            loss=loss,
-            startup_program=startup_program,
-            parameters=parameters,
-            no_grad_set=no_grad_set)
-        scaled_params = self._scale_parameters(params_grads)
-        for p_grad_sgrad in scaled_params:
-            param, grad, scaled_param = p_grad_sgrad
-            with param.block.program._optimized_guard(
-                [param, grad]), framework.name_scope('weight decay'):
-                updated_param = paddle.fluid.layers.elementwise_sub(
-                    x=param, y=scaled_param)
-                paddle.fluid.layers.assign(input=updated_param, output=param)
-
-        optimize_ops = self._apply_optimize(
-            loss=loss,
-            params_grads=params_grads,
-            startup_program=startup_program)
-        return optimize_ops, params_grads
-
-    @framework.dygraph_only
-    @imperative_base.no_grad
-    def step(self):
-        self._dtype = None
-        params_grads = []
-        for param in self._parameter_list:
-            if not param.trainable:
-                continue
-            if param._grad_ivar() is not None:
-                grad_var = param._grad_ivar()
-                params_grads.append((param, grad_var))
-
-        scaled_params = self._scale_parameters(params_grads)
-        for p_grad_sgrad in scaled_params:
-            param, grad, scaled_param = p_grad_sgrad
-            with param.block.program._optimized_guard(
-                [param, grad]), framework.name_scope('weight decay'):
-                updated_param = paddle.fluid.layers.elementwise_sub(
-                    x=param, y=scaled_param)
-                paddle.fluid.layers.assign(input=updated_param, output=param)
-        self._apply_optimize(
-            loss=None, startup_program=None, params_grads=params_grads)
+        param, grad = param_and_grad
+
+        if self._apply_decay_param_fun is not None \
+                and not self._apply_decay_param_fun(param.name):
+            return
+
+        if isinstance(self._learning_rate, float):
+            learning_rate = self._learning_rate
+        else:
+            # NOTE: this function is called from _append_optimize_op(), because
+            # _create_param_lr() must be called after
+            # optimizer._create_global_learning_rate().
+            learning_rate = self._create_param_lr(param_and_grad)
+
+        with block.program._optimized_guard(
+            [param, grad]), framework.name_scope('weight decay'):
+            self._params_name.add(param.name)
+
+            # If the decay coefficient has already been computed for this lr, reuse it
+            decay_coeff = self._lr_to_coeff.get(learning_rate, None)
+            if decay_coeff is None:
+                decay_coeff = 1.0 - learning_rate * self._coeff
+                self._lr_to_coeff[learning_rate] = decay_coeff
+
+            scaled_param = param * decay_coeff
+            paddle.fluid.layers.assign(input=scaled_param, output=param)
+
+    def _append_optimize_op(self, block, param_and_grad):
+        self._append_decoupled_weight_decay(block, param_and_grad)
+        return super(AdamW, self)._append_optimize_op(block, param_and_grad)
 
     def __str__(self):
         return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
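
For context on the new method: _append_decoupled_weight_decay() applies the update described in its docstring, parameter = parameter - parameter * coeff * lr, which is the same as scaling the parameter by 1 - lr * coeff before the regular Adam step runs. Below is a minimal plain-Python sketch of that decoupled decay, not Paddle code; the names decoupled_decay and apply_decay_fn are made up for illustration.

# Minimal sketch of decoupled weight decay (the AdamW idea), independent of Paddle.
# decoupled_decay and apply_decay_fn are illustrative names, not Paddle APIs.
def decoupled_decay(params, lr, coeff, apply_decay_fn=None):
    """Scale each parameter by (1 - lr * coeff) before the Adam update."""
    decay_coeff = 1.0 - lr * coeff  # the factor the patch caches in _lr_to_coeff
    decayed = {}
    for name, value in params.items():
        # Mirrors apply_decay_param_fun: skip parameters the callback rejects.
        if apply_decay_fn is not None and not apply_decay_fn(name):
            decayed[name] = value
            continue
        # parameter = parameter - parameter * coeff * lr
        decayed[name] = value * decay_coeff
    return decayed

params = {"linear_0.w_0": 1.0, "linear_0.b_0": 0.5}
decayed = decoupled_decay(params, lr=0.1, coeff=0.01,
                          apply_decay_fn=lambda n: n.endswith(".w_0"))
# The weight is scaled by 1 - 0.1 * 0.01 = 0.999; the bias is left unchanged.

Because the decay multiplies the parameter directly instead of being folded into the gradient, it is not rescaled by Adam's moment estimates, which is the point of decoupling.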
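Since the decay is now appended per parameter inside _append_optimize_op(), the minimize()/step() overrides are no longer needed and the training loop keeps the usual dygraph pattern. A hedged usage sketch follows, assuming the class is exposed as paddle.optimizer.AdamW and keeping the keyword names visible in the constructor above (coeff, apply_decay_param_fun); the Linear layer and random input are placeholders only.

import paddle

linear = paddle.nn.Linear(10, 10)
opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=linear.parameters(),
    coeff=0.01,  # weight decay coefficient stored as self._coeff
    # Only parameter names for which this callable returns True are decayed;
    # here weights (".w_0") are decayed and biases are skipped.
    apply_decay_param_fun=lambda name: name.endswith(".w_0"))

out = linear(paddle.rand([4, 10]))
loss = paddle.mean(out)
loss.backward()
opt.step()        # Adam update; the decoupled decay op is appended per parameter
opt.clear_grad()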