@@ -68,14 +68,16 @@ def __init__(self,
                  regularization=None,
                  grad_clip=None,
                  name=None):
+        # Imported inside the function body to avoid a circular import
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         self._parameter_list = list(
             parameter_list) if parameter_list is not None else None
         self._name = name
         if framework.in_dygraph_mode():
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, LearningRateDecay):
+            if not isinstance(learning_rate,
+                              (float, LearningRateDecay, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or LearningRateDecay, got %s here"
+                    "learning rate should be float or _LRScheduler, got %s here"
                     % type(learning_rate))
             if self._parameter_list is None:
                 raise AttributeError(
@@ -90,11 +92,11 @@ def __init__(self,
                             % regularization.__str__())
                         break
         else:
-            if not isinstance(learning_rate, float) and \
-                    not isinstance(learning_rate, framework.Variable):
+            if not isinstance(learning_rate,
+                              (float, framework.Variable, _LRScheduler)):
                 raise TypeError(
-                    "learning rate should be float or Variable, got %s here" %
-                    type(learning_rate))
+                    "learning rate should be float or _LRScheduler, got %s here"
+                    % type(learning_rate))
 
         if grad_clip is not None:
             if not isinstance(grad_clip, GradientClipBase):
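
Note: with the widened isinstance checks above, the dygraph branch now accepts a float, a LearningRateDecay object, or an _LRScheduler instance, and the static-graph branch accepts a float, a Variable, or an _LRScheduler. A minimal dygraph sketch, mirroring the docstring examples later in this file (passing an _LRScheduler instance instead of the float would now also pass the check):

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        emb = fluid.dygraph.Embedding([10, 10])
        # A plain float still works; an _LRScheduler instance is now accepted
        # here as well, while any other type raises the TypeError above.
        adam = fluid.optimizer.Adam(
            learning_rate=0.001, parameter_list=emb.parameters())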
@@ -144,11 +146,15 @@ def state_dict(self):
                 state_dict = adam.state_dict()
 
         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
         state_dict = {}
         for k, v in self._accumulators.items():
             for para_name, var_tmp in v.items():
                 state_dict[var_tmp.name] = var_tmp
         # global step if use lr decay
+        if isinstance(self._learning_rate, _LRScheduler):
+            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
+            return state_dict
         if isinstance(self._learning_rate, LearningRateDecay):
             state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
 
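
Note: when the learning rate is an _LRScheduler, state_dict() now returns early with the scheduler's own state bundled under the "LR_Scheduler" key, alongside the accumulator variables collected above. Continuing with the adam object from the docstring example (a usage sketch, not part of this diff):

    state = adam.state_dict()
    # Accumulator variables keyed by name, plus the scheduler (or
    # LearningRateDecay) state when one is attached.
    scheduler_state = state.get("LR_Scheduler")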
@@ -192,6 +198,9 @@ def set_dict(self, state_dict):
                 adam.set_dict(opti_state_dict)
 
         '''
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            self._learning_rate.set_dict(state_dict["LR_Scheduler"])
 
         if isinstance(self._learning_rate, LearningRateDecay):
             self._learning_rate.set_dict(state_dict["LR_Scheduler"])
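
Note: the restore path mirrors the save path; set_dict() hands the "LR_Scheduler" entry to the scheduler before the rest of the optimizer state is applied. Continuing the docstring example above (the "paddle_dy" path is illustrative):

    para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy")
    adam.set_dict(opti_state_dict)   # scheduler state restored via its set_dict()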
@@ -252,6 +261,30 @@ def get_opti_var_name_list(self):
         return self._opti_name_list
 
     def _create_global_learning_rate(self):
+        from paddle.optimizer.lr_scheduler import _LRScheduler
+        if isinstance(self._learning_rate, _LRScheduler):
+            lr_var = self._global_learning_rate()
+            # only create global lr_var once
+            if not isinstance(lr_var, framework.Variable):
+                lr_name = unique_name.generate('learning_rate')
+                self._learning_rate._var_name = lr_name
+                lr_var = self.helper.create_global_variable(
+                    name=lr_name,
+                    shape=[1],
+                    persistable=True,
+                    stop_gradient=True,
+                    dtype='float32' if self._dtype is None else self._dtype)
+                main_prog = framework.default_main_program()
+                main_prog.lr_sheduler = self._learning_rate
+                main_prog.lr_var = lr_var
+                self._learning_rate_map[framework.default_main_program(
+                )] = lr_var
+
+            lr_value = float(self._learning_rate())
+            self.helper.set_variable_initializer(
+                lr_var, initializer=Constant(value=lr_value))
+            return
+
         if imperative_base.enabled():
             # create learning rate Variable
             if isinstance(self._learning_rate, float):
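
Note: for the static-graph path, this branch creates one persistable float32 Variable of shape [1] the first time it runs, records the scheduler and the Variable on the default main program (lr_sheduler / lr_var), and seeds the Variable with the scheduler's current value through a Constant initializer (the float(self._learning_rate()) call assumes _LRScheduler instances are callable and return the current rate). A standalone sketch of the same fill mechanism, with 0.01 standing in for the scheduler's current value and a hypothetical variable name:

    import paddle.fluid as fluid

    # Persistable [1] float32 variable seeded by a constant initializer,
    # analogous to the lr_var created in the hunk above.
    lr_var = fluid.layers.create_global_var(
        shape=[1], value=0.01, dtype='float32',
        persistable=True, name='learning_rate_sketch')

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    print(exe.run(fluid.default_main_program(), fetch_list=[lr_var]))  # [array([0.01], dtype=float32)]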