23 | 23 | K = tf.keras.backend
24 | 24 |
25 | 25 |
26 |    | -class LazyAdam(tf.keras.optimizers.Adam):
27 |    | -  """Variant of the Adam optimizer that handles sparse updates more efficiently.
28 |    | -
29 |    | -  The original Adam algorithm maintains two moving-average accumulators for
30 |    | -  each trainable variable; the accumulators are updated at every step.
31 |    | -  This class provides lazier handling of gradient updates for sparse
32 |    | -  variables. It only updates moving-average accumulators for sparse variable
33 |    | -  indices that appear in the current batch, rather than updating the
34 |    | -  accumulators for all indices. Compared with the original Adam optimizer,
35 |    | -  it can provide large improvements in model training throughput for some
36 |    | -  applications. However, it provides slightly different semantics than the
37 |    | -  original Adam algorithm, and may lead to different empirical results.
38 |    | -  Note, amsgrad is currently not supported and the argument can only be
39 |    | -  False.
40 |    | -
41 |    | -  This class is borrowed from:
42 |    | -  https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/lazy_adam.py
43 |    | -  """
44 |    | -
45 |    | -  def _resource_apply_sparse(self, grad, var, indices):
46 |    | -    """Applies grad for one step."""
47 |    | -    var_dtype = var.dtype.base_dtype
48 |    | -    lr_t = self._decayed_lr(var_dtype)
49 |    | -    beta_1_t = self._get_hyper('beta_1', var_dtype)
50 |    | -    beta_2_t = self._get_hyper('beta_2', var_dtype)
51 |    | -    local_step = tf.cast(self.iterations + 1, var_dtype)
52 |    | -    beta_1_power = tf.math.pow(beta_1_t, local_step)
53 |    | -    beta_2_power = tf.math.pow(beta_2_t, local_step)
54 |    | -    epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
55 |    | -    lr = (lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power))
56 |    | -
57 |    | -    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
58 |    | -    m = self.get_slot(var, 'm')
59 |    | -    m_t_slice = beta_1_t * tf.gather(m, indices) + (1 - beta_1_t) * grad
60 |    | -
61 |    | -    m_update_kwargs = {
62 |    | -        'resource': m.handle,
63 |    | -        'indices': indices,
64 |    | -        'updates': m_t_slice
65 |    | -    }
66 |    | -    m_update_op = tf.raw_ops.ResourceScatterUpdate(**m_update_kwargs)
67 |    | -
68 |    | -    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
69 |    | -    v = self.get_slot(var, 'v')
70 |    | -    v_t_slice = (beta_2_t * tf.gather(v, indices) +
71 |    | -                 (1 - beta_2_t) * tf.math.square(grad))
72 |    | -
73 |    | -    v_update_kwargs = {
74 |    | -        'resource': v.handle,
75 |    | -        'indices': indices,
76 |    | -        'updates': v_t_slice
77 |    | -    }
78 |    | -    v_update_op = tf.raw_ops.ResourceScatterUpdate(**v_update_kwargs)
79 |    | -
80 |    | -    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
81 |    | -    var_slice = lr * m_t_slice / (tf.math.sqrt(v_t_slice) + epsilon_t)
82 |    | -
83 |    | -    var_update_kwargs = {
84 |    | -        'resource': var.handle,
85 |    | -        'indices': indices,
86 |    | -        'updates': var_slice
87 |    | -    }
88 |    | -    var_update_op = tf.raw_ops.ResourceScatterSub(**var_update_kwargs)
89 |    | -
90 |    | -    return tf.group(*[var_update_op, m_update_op, v_update_op])
91 |    | -
92 |    | -
93 | 26 | class LearningRateFn(object):
94 | 27 |   """Creates learning rate function."""
95 | 28 |
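The deleted class was a copy of the TensorFlow Addons implementation (its docstring links to it), so code that still needs lazy sparse Adam updates can presumably migrate to the upstream optimizer. A minimal sketch, assuming the `tensorflow-addons` package is installed; the toy embedding model and random data below are illustrative only and not part of this commit:

```python
import tensorflow as tf
import tensorflow_addons as tfa  # assumption: tensorflow-addons is available

# Toy embedding model: each batch only touches a few embedding rows, which is
# exactly the sparse-gradient case the lazy optimizer is meant to speed up.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=1000, output_dim=16),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(1),
])

# tfa.optimizers.LazyAdam updates the 'm'/'v' slots only for the indices seen
# in the current batch, mirroring the behavior of the removed class.
model.compile(optimizer=tfa.optimizers.LazyAdam(learning_rate=1e-3), loss='mse')

x = tf.random.uniform((32, 8), maxval=1000, dtype=tf.int32)  # fake token ids
y = tf.random.uniform((32, 1))                               # fake targets
model.fit(x, y, epochs=1, verbose=0)
```

As the removed docstring notes, the lazy variant trades exact Adam semantics for throughput, so results can differ slightly from the stock Adam optimizer.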