Closed
34 commits
e49c805
add gradient accumulator
fsx950223 Jul 15, 2021
2c0fbae
add exceptions
fsx950223 Jul 15, 2021
11e536d
fix multi gpus bug
fsx950223 Jul 15, 2021
1a4c0d4
fix test bugs
fsx950223 Jul 15, 2021
eabed95
fix sparse optimizer
fsx950223 Jul 15, 2021
a6ff7c0
remove read_value
fsx950223 Jul 15, 2021
24ae8a9
fix sparse test
fsx950223 Jul 15, 2021
2760fad
fix sparse bug
fsx950223 Jul 15, 2021
4ba7a55
refactor
fsx950223 Jul 15, 2021
dc50184
add sparse multi gpu test
fsx950223 Jul 16, 2021
8cd65ad
fix rnn bug
fsx950223 Jul 18, 2021
7d40946
fix step bugs
fsx950223 Jul 19, 2021
6949bd3
fix _iterations
fsx950223 Jul 19, 2021
9e423e5
use gradient transformer
fsx950223 Jul 19, 2021
7f3b2e9
fix bug
fsx950223 Jul 19, 2021
99dcde5
fix step bug
fsx950223 Jul 19, 2021
a184581
simpify code
fsx950223 Jul 19, 2021
d0718f8
optimize
fsx950223 Jul 19, 2021
2af5475
fix bug
fsx950223 Jul 19, 2021
42fccea
fix bug
fsx950223 Jul 19, 2021
93794ec
simpify code
fsx950223 Jul 19, 2021
e62cc95
add mean reduction
fsx950223 Jul 19, 2021
64b70b4
decrease memory usage
fsx950223 Jul 20, 2021
4dbc208
add name
fsx950223 Jul 20, 2021
b314592
Implement GA alternative
stefan-falk Jul 20, 2021
222e757
Run black formatter
stefan-falk Jul 20, 2021
1c7bb61
Merge branch 'ga' into ga-alt
stefan-falk Jul 20, 2021
5e75bd7
Update gradient_accumulator.py
stefan-falk Jul 20, 2021
b31c896
Fixing code mess up
stefan-falk Jul 20, 2021
a03643e
Add embedding test
stefan-falk Jul 20, 2021
0a8e686
Add currently failing LSTM-test
stefan-falk Jul 20, 2021
40b6e38
use default-strategy
stefan-falk Jul 20, 2021
6db2187
Use custom implementation for GA
stefan-falk Jul 20, 2021
4142f37
Some cleaning up
stefan-falk Jul 20, 2021
1 change: 1 addition & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -32,6 +32,7 @@
from tensorflow_addons.optimizers.lamb import LAMB
from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.lookahead import Lookahead
from tensorflow_addons.optimizers.gradient_accumulator import GradientAccumulator
from tensorflow_addons.optimizers.moving_average import MovingAverage
from tensorflow_addons.optimizers.novograd import NovoGrad
from tensorflow_addons.optimizers.proximal_adagrad import ProximalAdagrad
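This one-line change re-exports the new wrapper from the package namespace, so downstream code can import it directly. A minimal import sketch (assuming the package is installed as `tensorflow_addons`):

```python
# Either of these imports resolves to the same symbol after this change:
from tensorflow_addons.optimizers import GradientAccumulator

import tensorflow_addons as tfa  # then: tfa.optimizers.GradientAccumulator
```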
105 changes: 105 additions & 0 deletions tensorflow_addons/optimizers/gradient_accumulator.py
@@ -0,0 +1,105 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Union, Dict, Hashable, List

import tensorflow as tf

from tensorflow_addons.utils import types


class AccumulationGradientTransformer:
_accu_gradients: Union[Dict[Hashable, tf.Variable], None] = None

def __init__(
self,
optimizer: types.Optimizer,
accu_steps: types.TensorLike,
trainable_variables,
):
self.optimizer = optimizer
self.accu_steps = accu_steps
self.step = tf.Variable(0, dtype=tf.int64, name="ga_step")
self._accu_gradients: Union[List[tf.Variable], None] = None
policy = tf.keras.mixed_precision.global_policy()
self.variable_dtype = policy.variable_dtype
self._accu_gradients = {
var.ref(): self.optimizer.add_slot(var, "ga") for var in trainable_variables
}

def __call__(self, grads_and_vars, *args, **kwargs):

variables = [var for (_, var) in grads_and_vars]
accu_gradients = self._accu_gradients
step_inc_op = self.step.assign_add(1, read_value=False)

with tf.control_dependencies([step_inc_op]):
can_apply = tf.cast(
self.step % self.accu_steps == 0, dtype=self.variable_dtype
)
accumulate = tf.cast(
self.step % (self.accu_steps + 1) != 0, dtype=self.variable_dtype
)

accum_ops = list()
for grad, var in grads_and_vars:

# Get the accumulated gradient
grad_accum = accu_gradients[var.ref()] * accumulate

if isinstance(grad, tf.IndexedSlices):
# Not sure why e.g. the Embedding layer requires an additional dimension here
grad_indices = grad.indices[..., None] if len(grad.indices.shape) < 2 else grad.indices
added = tf.IndexedSlices(
values=grad.values
+ tf.gather_nd(grad_accum, grad_indices),
indices=grad.indices,
dense_shape=grad.dense_shape,
)
accu_op = accu_gradients[var.ref()].scatter_update(added)
else:
accu_op = accu_gradients[var.ref()].assign(
grad + grad_accum, read_value=False
)

accum_ops.append(accu_op)

iter_dec_op = self.optimizer.iterations.assign_add(
-1 * tf.cast(can_apply, dtype=self.optimizer.iterations.dtype),
read_value=False,
)

with tf.control_dependencies(accum_ops + [iter_dec_op]):
gradients = [accu_gradients[var.ref()] * can_apply for var in variables]
return list(zip(gradients, variables))


def GradientAccumulator(
optimizer: types.Optimizer, trainable_variables, accu_steps: int = 2
) -> types.Optimizer:
if trainable_variables is None:
trainable_variables = list()

if isinstance(optimizer, str):
optimizer = tf.keras.optimizers.get(optimizer)

optimizer.gradient_transformers.append(
AccumulationGradientTransformer(
optimizer=optimizer,
accu_steps=accu_steps,
trainable_variables=trainable_variables,
)
)

return optimizer
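
For context, `GradientAccumulator` mutates and returns the optimizer it is given, so the result can be dropped into an ordinary Keras training loop. A minimal usage sketch (the toy model, data, and `accu_steps=4` are illustrative assumptions, not part of the PR):

```python
import numpy as np
import tensorflow as tf

from tensorflow_addons.optimizers import GradientAccumulator

# Toy model: its variables must already exist when the wrapper is created,
# because the transformer adds its "ga" slots in __init__.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

optimizer = GradientAccumulator(
    tf.keras.optimizers.SGD(learning_rate=0.1),
    trainable_variables=model.trainable_variables,
    accu_steps=4,  # accumulated gradients are applied every 4th step
)

model.compile(optimizer=optimizer, loss="mse")

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, batch_size=8, epochs=1)
```

Because the accumulation is expressed as a `gradient_transformers` entry rather than an optimizer subclass, the wrapped optimizer keeps its own hyperparameters, slots, and serialization behaviour; the transformer only rescales and re-routes gradients before `apply_gradients` consumes them.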