@@ -36,6 +36,22 @@ struct ApplyGradientDescent<CPUDevice, T> {
   }
 };

+template <typename T>
+struct ApplyAdadelta<CPUDevice, T> {
+  void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+                  typename TTypes<T>::Flat accum,
+                  typename TTypes<T>::Flat accum_update,
+                  typename TTypes<T>::ConstScalar lr,
+                  typename TTypes<T>::ConstScalar rho,
+                  typename TTypes<T>::ConstScalar epsilon,
+                  typename TTypes<T>::ConstFlat grad) {
+    accum.device(d) = accum * rho() + grad.square() * (1 - rho());
+    const auto update = (accum_update + epsilon()).sqrt() * (accum + epsilon()).rsqrt() * grad;
+    accum_update.device(d) = accum_update * rho() + update.square() * (1 - rho());
+    var.device(d) -= update * lr();
+  }
+};
+
 template <typename T>
 struct ApplyAdagrad<CPUDevice, T> {
   void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
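For reference, the CPU functor above (and the sparse kernel added later in this change) follows Zeiler's Adadelta recurrences, with E[g^2] held in accum, E[\Delta x^2] in accum_update, and the learning rate only scaling the final variable step. A sketch in standard notation:

\begin{aligned}
E[g^2]_t &= \rho\,E[g^2]_{t-1} + (1-\rho)\,g_t^2 \\
u_t &= \frac{\sqrt{E[\Delta x^2]_{t-1} + \epsilon}}{\sqrt{E[g^2]_t + \epsilon}}\,g_t \\
E[\Delta x^2]_t &= \rho\,E[\Delta x^2]_{t-1} + (1-\rho)\,u_t^2 \\
x_{t+1} &= x_t - \mathrm{lr}\cdot u_t
\end{aligned}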
@@ -224,6 +240,266 @@ REGISTER_KERNELS(GPU, double);
 #endif
 #undef REGISTER_KERNELS

+template <typename Device, typename T>
+class ApplyAdadeltaOp : public OpKernel {
+ public:
+  explicit ApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    if (use_exclusive_lock_) {
+      mutex_lock l1(*ctx->input_ref_mutex(0));
+      // Don't try to acquire a lock on the second ref as they share the same
+      // mutex.
+      //
+      // mutex_lock l2(*ctx->input_ref_mutex(1));
+      DoValidate(ctx);
+      if (!ctx->status().ok()) return;
+      DoCompute(ctx);
+    } else {
+      DoValidate(ctx);
+      if (!ctx->status().ok()) return;
+      DoCompute(ctx);
+    }
+    ctx->forward_ref_input_to_ref_output(0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;
+
+  void DoValidate(OpKernelContext* ctx) {
+    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(0)));
+    OP_REQUIRES(
+        ctx, accum.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(1)));
+    OP_REQUIRES(
+        ctx, accum_update.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(2)));
+
+    const Tensor& lr = ctx->input(3);
+    const Tensor& rho = ctx->input(4);
+    const Tensor& epsilon = ctx->input(5);
+    const Tensor& grad = ctx->input(6);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr.shape().DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
+                errors::InvalidArgument("rho is not a scalar: ",
+                                        rho.shape().DebugString()));
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(accum.shape()),
+        errors::InvalidArgument("var and accum do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                accum.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(grad.shape()),
+        errors::InvalidArgument("var and grad do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                grad.shape().DebugString()));
+  }
+
+  void DoCompute(OpKernelContext* ctx) {
+    const Device& device = ctx->template eigen_device<Device>();
+    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+
+    const Tensor& lr = ctx->input(3);
+    const Tensor& rho = ctx->input(4);
+    const Tensor& epsilon = ctx->input(5);
+    const Tensor& grad = ctx->input(6);
+
+    functor::ApplyAdadelta<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
+                                        accum_update.flat<T>(), lr.scalar<T>(),
+                                        rho.scalar<T>(), epsilon.scalar<T>(),
+                                        grad.flat<T>());
+  }
+};
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER_KERNELS(D, T)                                         \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("ApplyAdadelta").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+      ApplyAdadeltaOp<D##Device, T>);
+
+REGISTER_KERNELS(CPU, float);
+REGISTER_KERNELS(CPU, double);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                  \
+  template <>                                                \
+  void ApplyAdadelta<GPUDevice, T>::operator()(              \
+      const GPUDevice& d, typename TTypes<T>::Flat var,      \
+      typename TTypes<T>::Flat accum,                        \
+      typename TTypes<T>::Flat accum_update,                 \
+      typename TTypes<T>::ConstScalar lr,                    \
+      typename TTypes<T>::ConstScalar rho,                   \
+      typename TTypes<T>::ConstScalar epsilon,               \
+      typename TTypes<T>::ConstFlat grad);                   \
+  extern template struct ApplyAdadelta<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_KERNELS
+
+// Note, this op works on cpu only.
+template <typename T, typename Tindex>
+class SparseApplyAdadeltaOp : public OpKernel {
+ public:
+  explicit SparseApplyAdadeltaOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+    mutex* mu_var = ctx->input_ref_mutex(0);
+    // mu_accum is actually the same mutex as mu_var since currently we use a
+    // global mutex.
+    //
+    // mutex* mu_accum = ctx->input_ref_mutex(1);
+    if (use_exclusive_lock_) {
+      mu_var->lock();
+    }
+    Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+    Tensor accum_grad = ctx->mutable_input(1, use_exclusive_lock_);
+    Tensor accum_update = ctx->mutable_input(2, use_exclusive_lock_);
+    OP_REQUIRES(
+        ctx, var.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(0)));
+    OP_REQUIRES(
+        ctx, accum_grad.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(1)));
+    OP_REQUIRES(
+        ctx, accum_update.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", def().input(2)));
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(accum_grad.shape()),
+        errors::InvalidArgument("var and accum_grad do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                accum_grad.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, var.shape().IsSameSize(accum_update.shape()),
+        errors::InvalidArgument("var and accum_update do not have the same shape",
+                                var.shape().DebugString(), " ",
+                                accum_update.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
+                errors::InvalidArgument("var must be at least 1 dimensional"));
+
+    const Tensor& lr = ctx->input(3);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+                errors::InvalidArgument("lr is not a scalar: ",
+                                        lr.shape().DebugString()));
+    const Tensor& rho = ctx->input(4);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
+                errors::InvalidArgument("rho is not a scalar: ",
+                                        rho.shape().DebugString()));
+    const Tensor& epsilon = ctx->input(5);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+                errors::InvalidArgument("epsilon is not a scalar: ",
+                                        epsilon.shape().DebugString()));
+    const Tensor& grad = ctx->input(6);
+    const Tensor& indices = ctx->input(7);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+                errors::InvalidArgument("indices must be one-dimensional"));
+
+    for (int d = 1; d < var.dims(); d++) {
+      OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
+                  errors::InvalidArgument(strings::StrCat(
+                      "var and grad must match in dimension ", d)));
+    }
+    const Tindex N = indices.dim_size(0);
+    OP_REQUIRES(
+        ctx, grad.dim_size(0) == N,
+        errors::InvalidArgument(
+            "grad must be the same size as indices in the first dimension."));
+
+    if (N > 0) {
+      const Tindex first_dim_size = var.dim_size(0);
+      // Validate all the indices are in range
+      auto indices_vec = indices.vec<Tindex>();
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+        OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
+                    errors::InvalidArgument(
+                        strings::StrCat("Index ", index, " at offset ", i,
+                                        " in indices is out of range")));
+      }
+
+      auto var_flat = var.flat_outer_dims<T>();
+      auto accum_grad_flat = accum_grad.flat_outer_dims<T>();
+      auto accum_update_flat = accum_update.flat_outer_dims<T>();
+      auto grad_flat = grad.flat_outer_dims<T>();
+      const T lr_scalar = lr.scalar<T>()();
+      const T rho_scalar = rho.scalar<T>()();
+      const T epsilon_scalar = epsilon.scalar<T>()();
+
+      for (Tindex i = 0; i < N; i++) {
+        const Tindex index = indices_vec(i);
+        auto accum_ = accum_grad_flat.template chip<0>(index);
+        auto accum_update_ = accum_update_flat.template chip<0>(index);
+        auto grad_ = grad_flat.template chip<0>(i);
+
+        accum_ = accum_ * accum_.constant(rho_scalar) + grad_.square() * grad_.constant(1 - rho_scalar);
+        const auto update = (accum_update_ + accum_update_.constant(epsilon_scalar)).sqrt() * (accum_ + accum_.constant(epsilon_scalar)).rsqrt() * grad_;
+        accum_update_ = accum_update_ * accum_update_.constant(rho_scalar) + update.square() * update.constant(1 - rho_scalar);
+
+        auto v = var_flat.template chip<0>(index);
+        v -= update * update.constant(lr_scalar);
+      }
+    }
+    if (use_exclusive_lock_) {
+      mu_var->unlock();
+    }
+
+    ctx->forward_ref_input_to_ref_output(0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(T, Tindices)                                \
+  REGISTER_KERNEL_BUILDER(Name("SparseApplyAdadelta")                \
+                              .Device(DEVICE_CPU)                    \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices"), \
+                          SparseApplyAdadeltaOp<T, Tindices>);
+
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+
+#undef REGISTER_KERNELS
+
 template <typename Device, typename T>
 class ApplyAdagradOp : public OpKernel {
  public:
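For readers skimming the sparse path: a minimal standalone sketch, assuming row-major flattened storage and hypothetical names (SparseAdadeltaStep is not part of the diff and uses no TensorFlow or Eigen types), of the per-row arithmetic the SparseApplyAdadelta kernel performs:

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of one sparse Adadelta step: for each i, row indices[i] of
// var/accum/accum_update is updated using row i of grad. Tensors are modeled
// as flat vectors of rows with row_size elements each.
void SparseAdadeltaStep(std::vector<float>& var, std::vector<float>& accum,
                        std::vector<float>& accum_update,
                        const std::vector<float>& grad,
                        const std::vector<std::size_t>& indices,
                        std::size_t row_size, float lr, float rho,
                        float epsilon) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const std::size_t row = indices[i];
    for (std::size_t j = 0; j < row_size; ++j) {
      const std::size_t k = row * row_size + j;  // offset into var/accums
      const float g = grad[i * row_size + j];    // gradient row i
      accum[k] = rho * accum[k] + (1 - rho) * g * g;
      const float update = std::sqrt(accum_update[k] + epsilon) /
                           std::sqrt(accum[k] + epsilon) * g;
      accum_update[k] = rho * accum_update[k] + (1 - rho) * update * update;
      var[k] -= lr * update;
    }
  }
}

As in the kernel's loop, duplicate entries in indices are applied sequentially, so a repeated row sees the accumulator values left by its earlier occurrences within the same call.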