@@ -198,18 +198,19 @@ def apply(self, grad_params, gpu=0):


 class ClipAdamOptimizer(optimizer.Optimizer):
-    def __init__(self, learning_rate=3e-4, beta1=0.9, beta2=0.999, epsilon=1e-8, clip_sigmas=0.0, grad_scale=1.0, zero_nans=True, name="ClipAdam"):
+    def __init__(self, learning_rate=3e-4, beta1=0.9, beta2=0.999, epsilon=1e-8, clip_sigmas=0.0, grad_scale=1.0, sat_infs=None, zero_nans=True, name="ClipAdam"):
         super().__init__(False, name)
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
+        self.sat_infs = sat_infs
         self.zero_nans = zero_nans
         self.name = name

         with tf.device("/cpu:0"), tf.variable_scope("adam_beta"):

-            if type(learning_rate) is float:
-                learning_rate = tf.constant(learning_rate)
+            if type(learning_rate) in (float, int):
+                learning_rate = tf.constant(float(learning_rate))
             if type(clip_sigmas) in (float, int):
                 clip_sigmas = tf.constant(float(clip_sigmas))
             if type(grad_scale) in (float, int):
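
For orientation, a hedged usage sketch of the widened constructor. The import path, the toy `w`/`loss`, and the TF1-style setup are assumptions (not part of this patch), and the fused CUDA kernels must be built for the op to actually run:

```python
# Sketch only: assumes the blocksparse package is importable and its CUDA ops are compiled;
# the module path below is a guess based on the class names in this diff.
import tensorflow as tf
from blocksparse.optimize import ClipAdamOptimizer  # hypothetical import path

w = tf.get_variable("w", shape=[1024, 1024], dtype=tf.float32)  # placeholder parameter
loss = tf.reduce_sum(tf.square(w))                              # placeholder loss

# sat_infs is simply threaded through the constructor; existing call sites keep working
opt = ClipAdamOptimizer(learning_rate=3e-4, clip_sigmas=3.0, sat_infs=True)
train_op = opt.minimize(loss)  # minimize() comes from the tf.train.Optimizer base class
```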
@@ -240,9 +241,12 @@ def _apply_dense(self, grad, param):
         m = self.get_slot(param, "m")
         v = self.get_slot(param, "v")

+        # a float32 grad could still contain infs from upstream fp16 math
+        sat_infs = grad.dtype is tf.float16 if self.sat_infs is None else self.sat_infs
+
         return adam_op(grad, param, m, v, self.lr, self.grad_scale, self.clip_sigma,
             decay_mean=self.beta1, decay_var=self.beta2, epsilon=self.epsilon,
-            zero_nans=self.zero_nans, lazy_update=hasattr(grad, "lazy")).out_param
+            sat_infs=sat_infs, zero_nans=self.zero_nans, lazy_update=hasattr(grad, "lazy")).out_param

     def _apply_sparse(self, grad, param):
         raise NotImplementedError("Sparse gradient updates are not supported.")
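
A minimal, standalone sketch of the default established by the new `sat_infs` line (the `resolve_sat_infs` helper is hypothetical, not part of the patch; only TensorFlow is assumed): inf saturation is on by default only for fp16 gradients, and a caller whose fp32 gradients come from fp16 math can opt in explicitly.

```python
import tensorflow as tf

def resolve_sat_infs(grad, sat_infs=None):
    # Hypothetical mirror of the default added in _apply_dense:
    # with sat_infs=None, saturation turns on only for fp16 gradients;
    # pass sat_infs=True when fp32 grads were produced by fp16 math upstream.
    return grad.dtype == tf.float16 if sat_infs is None else sat_infs

g16 = tf.zeros([4], dtype=tf.float16)
g32 = tf.zeros([4], dtype=tf.float32)
assert resolve_sat_infs(g16) is True
assert resolve_sat_infs(g32) is False
assert resolve_sat_infs(g32, sat_infs=True) is True
```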
@@ -255,3 +259,93 @@ def _finish(self, update_ops, name_scope):

         return tf.group(*update_ops + [update_beta1, update_beta2], name=name_scope)

+
+class AdamOptimizer(ClipAdamOptimizer):
+    def __init__(self, learning_rate=3e-4, beta1=0.9, beta2=0.999, epsilon=1e-8, grad_scale=1.0, sat_infs=None, zero_nans=True, name="Adam"):
+        super().__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, grad_scale=grad_scale, sat_infs=sat_infs, zero_nans=zero_nans, name=name)
+
+
+############################## AdafactorOptimizer #####################################
+
+adafactor1d_op = _op_module.adafactor1d
+adafactor2d_op = _op_module.adafactor2d
+
+class AdafactorOptimizer(optimizer.Optimizer):
+    def __init__(self, learning_rate=5e-4, beta2=0.999, epsilon=1e-30, clip_thresh=1.0, grad_scale=1.0, sat_infs=None, zero_nans=True, name="Adafactor"):
+        super().__init__(False, name)
+        self.epsilon = epsilon
+        self.sat_infs = sat_infs
+        self.zero_nans = zero_nans
+        self.name = name
+
+        with tf.device("/cpu:0"), tf.variable_scope("adafactor_decay"):
+
+            if type(learning_rate) in (float, int):
+                learning_rate = tf.constant(float(learning_rate))
+            if type(clip_thresh) in (float, int):
+                clip_thresh = tf.constant(float(clip_thresh))
+            if type(grad_scale) in (float, int):
+                grad_scale = tf.constant(float(grad_scale))
+            one = tf.constant(1.0)
+
+            self.decay1_power = tf.Variable(initial_value=beta2, name="decay1_power", trainable=False)
+            self.decay2_power = tf.Variable(initial_value=beta2*beta2, name="decay2_power", trainable=False)
+            self.learn_rate = learning_rate
+            self.clip_thresh = clip_thresh
+            self.grad_scale = grad_scale
+            self.decay_t = tf.constant(beta2)
+            self.decay = self.decay_t * (one - self.decay1_power) / (one - self.decay2_power)
+
+    def _get_beta_accumulators(self):
+        return self.decay1_power, self.decay2_power
+
+    def _non_slot_variables(self):
+        return self._get_beta_accumulators()
+
+    def _create_slots(self, params):
+        # Create slots for the first and second moments.
+        for param in params:
+            if param.shape.ndims == 2 and param.shape[0].value > 1:
+                self._get_or_make_slot(param, tf.zeros(param.shape[1].value), "cv", self.name + "CV")
+                self._get_or_make_slot(param, tf.zeros(param.shape[0].value), "rv", self.name + "RV")
+            elif param.shape.ndims == 1 or (param.shape.ndims == 2 and param.shape[0].value == 1):
+                self._get_or_make_slot(param, tf.zeros(param.shape.num_elements()), "cv", self.name + "CV")
+            else:
+                raise ValueError("only 1 or 2d params are supported")
+
+    def _apply_dense(self, grad, param):
+
+        # a float32 grad could still contain infs from upstream fp16 math
+        sat_infs = grad.dtype is tf.float16 if self.sat_infs is None else self.sat_infs
+
+        if param.shape.ndims == 2 and param.shape[0].value > 1:
+
+            cv = self.get_slot(param, "cv")
+            rv = self.get_slot(param, "rv")
+
+            return adafactor2d_op(param, cv, rv, grad,
+                self.decay, self.learn_rate, self.grad_scale, self.clip_thresh,
+                epsilon=self.epsilon, sat_infs=sat_infs, zero_nans=self.zero_nans).out_param
+
+        elif param.shape.ndims == 1 or (param.shape.ndims == 2 and param.shape[0].value == 1):
+
+            cv = self.get_slot(param, "cv")
+
+            return adafactor1d_op(param, cv, grad,
+                self.decay, self.learn_rate, self.grad_scale, self.clip_thresh,
+                epsilon=self.epsilon, sat_infs=sat_infs, zero_nans=self.zero_nans).out_param
+        else:
+            raise ValueError("only 1 or 2d params are supported")
+
+    def _apply_sparse(self, grad, param):
+        raise NotImplementedError("Sparse gradient updates are not supported.")
+
+    def _finish(self, update_ops, name_scope):
+        # Update the power accumulators.
+        with ops.control_dependencies([self.decay]), tf.device("/cpu:0"):
+            update_decay1 = self.decay1_power.assign(self.decay1_power * self.decay_t)
+            update_decay2 = self.decay2_power.assign(self.decay2_power * self.decay_t)
+
+        return tf.group(*update_ops + [update_decay1, update_decay2], name=name_scope)
+
+
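
As a rough illustration of what the new `AdafactorOptimizer` stores per parameter (a plain-Python sketch of the shapes `_create_slots` allocates; the helper below is hypothetical): a full 2-D weight gets a column-variance vector `cv` and a row-variance vector `rv`, while 1-D or `[1, n]` parameters get only `cv`, so the factored second moment costs O(rows + cols) memory rather than O(rows * cols).

```python
def adafactor_slot_shapes(shape):
    """Hypothetical helper mirroring the slot layout of AdafactorOptimizer._create_slots."""
    if len(shape) == 2 and shape[0] > 1:
        # full matrix: factored second moment, one variance vector per axis
        return {"cv": (shape[1],), "rv": (shape[0],)}
    elif len(shape) == 1 or (len(shape) == 2 and shape[0] == 1):
        # vector (or 1 x n) parameter: a single variance vector
        n = shape[0] if len(shape) == 1 else shape[1]
        return {"cv": (n,)}
    raise ValueError("only 1 or 2d params are supported")

# e.g. a 1024 x 4096 weight keeps 1024 + 4096 statistics instead of ~4.2M
print(adafactor_slot_shapes((1024, 4096)))  # {'cv': (4096,), 'rv': (1024,)}
print(adafactor_slot_shapes((4096,)))       # {'cv': (4096,)}
```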