Unable to use gradient-accumulation with mixed-precision #14829

@stefan-falk

Overview

  • Python: 3.9.0
  • Tensorflow: 2.5.0 (v2.5.0-rc3-213-ga4dfb8d1a7)
  • CUDA: 11.2

Description

I was able to implement gradient-accumulation by patching an arbitrary optimizer's apply_gradients() function with the replacement below.

Just to explain the basic idea: every n steps the apply-signal is 1.0, so all gradients remain untouched and calling the _orig_apply_gradients() function simply updates the weights. Since every incoming gradient is divided by n before being accumulated, the buffer holds the mean gradient over the accumulation window by the time it is applied.

Whenever this is not the case, the apply-signal is 0.0 and all gradients are set to zero, resulting in an update which has no effect. This is just a workaround: I wasn't able to use tf.cond() in this context (see the sketch after the snippet below) and I am not aware of any other way to do this as of now.

def apply_gradients(self, grads_and_vars, *args, **kwargs):
    can_apply = self._can_apply_on_next_step()
    # 1.0 whenever we want to apply gradients; 0.0 otherwise
    apply = tf.cast(can_apply, dtype=self.variable_dtype)
    # Will be 0.0 if apply is 1.0 and vice versa
    keep = tf.cast(tf.logical_not(can_apply), dtype=self.variable_dtype)

    grads_and_vars = list(grads_and_vars)
    gradients = [grad for (grad, _) in grads_and_vars]
    trainable_variables = [var for (_, var) in grads_and_vars]

    accu_gradients = self._get_accu_gradients_for(gradients)

    # Accumulate gradients
    for i, grad in enumerate(gradients):
        accu_gradients[i].assign_add(grad / tf.cast(self.n, dtype=grad.dtype))

    # Multiply each gradient by our apply-signal
    final_gradients = [grad * apply for grad in accu_gradients]

    # Call the original apply_gradients() function
    self._orig_apply_gradients(zip(final_gradients, trainable_variables), *args, **kwargs)

    # Undo the increment of iterations whenever we did not apply gradients
    self.iterations.assign_add(-1 * tf.cast(keep, dtype=self.iterations.dtype))

    # This will reset our buffer whenever "keep" is 0.0
    for g in accu_gradients:
        g.assign(g * keep)
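
For reference, here is roughly the tf.cond()-based variant I would have preferred. This is only a sketch reusing the helpers from the full example below, and it is exactly the construction I could not get to work in this context:

def apply_gradients(self, grads_and_vars, *args, **kwargs):
    grads_and_vars = list(grads_and_vars)
    gradients = [grad for (grad, _) in grads_and_vars]
    trainable_variables = [var for (_, var) in grads_and_vars]
    accu_gradients = self._get_accu_gradients_for(gradients)

    # Always accumulate first
    for i, grad in enumerate(gradients):
        accu_gradients[i].assign_add(grad / tf.cast(self.n, dtype=grad.dtype))

    def apply_fn():
        self._orig_apply_gradients(zip(accu_gradients, trainable_variables), *args, **kwargs)
        # Reset the accumulator buffers after applying
        for g in accu_gradients:
            g.assign(tf.zeros_like(g))
        return tf.constant(True)

    def skip_fn():
        return tf.constant(False)

    # Both branches must return the same structure
    return tf.cond(self._can_apply_on_next_step(), apply_fn, skip_fn)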

My problem is that I am not able to use this patch in combination with a mixed_float16 mixed-precision policy, and I was hoping somebody could tell me how I might fix this.

In that case I receive a ValueError:

ValueError: Lengths of branch outputs of cond must match.
len(graphs[0].outputs): 0
len(graphs[1].outputs): 1

MNIST example

I have implemented a runnable MNIST example which you can try out below. Training works with the float32 policy but not with mixed_float16.

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision


class GradientAccumulationOptimizer:
    __create_key = object()

    def __init__(self, create_key, optimizer: keras.optimizers.Optimizer, n: int):
        """Optimizer patch for gradient accumulation.
        :param create_key:
            Object create key for private constructor.
        :param optimizer:
            The unpatched optimizer.
        :param n:
            The number of accumulation steps.
        """
        if create_key != GradientAccumulationOptimizer.__create_key:
            raise RuntimeError(
                'Don\'t call constructor directly. Call GradientAccumulationOptimizer.patch instead.'
            )

        self.n = tf.constant(n, dtype=tf.int64)
        policy = tf.keras.mixed_precision.global_policy()
        self.variable_dtype = policy.variable_dtype
        self._optimizer = optimizer
        self._accu_gradients = None
        self._current_step = tf.Variable(0, dtype=tf.int64)
        self._orig_apply_gradients = optimizer.apply_gradients

    @property
    def iterations(self):
        return self._optimizer.iterations

    def apply_gradients(self, grads_and_vars, *args, **kwargs):
        """Applies gradients in a gradient accumulating fashion.
        :param grads_and_vars:
            Gradients to optimize and model variables from the training batch.
        :param args:
            Additional args for the internal optimizer.
        :param kwargs:
            Additional kwargs for the internal optimizer.
        """

        can_apply = self._can_apply_on_next_step()
        # 1.0 whenever we want to apply gradients; 0.0 otherwise
        apply = tf.cast(can_apply, dtype=self.variable_dtype)
        # Will be 0.0 if apply is 1.0 and vice versa
        keep = tf.cast(tf.logical_not(can_apply), dtype=self.variable_dtype)

        grads_and_vars = list(grads_and_vars)
        gradients = [grad for (grad, _) in grads_and_vars]
        trainable_variables = [var for (_, var) in grads_and_vars]

        accu_gradients = self._get_accu_gradients_for(gradients)

        # Accumulate gradients
        for i, grad in enumerate(gradients):
            accu_gradients[i].assign_add(grad / tf.cast(self.n, dtype=grad.dtype))

        # Multiply each gradient by our apply-signal
        final_gradients = [grad * apply for grad in accu_gradients]

        # Call the original apply_gradients() function
        self._orig_apply_gradients(zip(final_gradients, trainable_variables), *args, **kwargs)

        # Undo the increment of iterations whenever we did not apply gradients
        self.iterations.assign_add(-1 * tf.cast(keep, dtype=self.iterations.dtype))

        # This will reset our buffer whenever "keep" is 0.0
        for g in accu_gradients:
            g.assign(g * keep)

    def _get_accu_gradients_for(self, gradients: list) -> list:
        """Returns the accumulator gradients for a given set of actual gradients."""
        if self._accu_gradients is None:
            self._accu_gradients = [
                tf.Variable(
                    tf.zeros_like(g),
                    trainable=False,
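                    # ON_READ: each replica writes its own local copy; values are only aggregated across replicas when read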
                    synchronization=tf.VariableSynchronization.ON_READ
                ) for g in gradients
            ]
        return self._accu_gradients

    def _can_apply_on_next_step(self):
        """Should be called only once in apply_gradients().
        :return: True if gradients should be applied; False otherwise.
        """
        # Increment (always do this first)
        self._current_step.assign_add(1)
        count_mod_steps = tf.math.mod(self._current_step, self.n)
        return tf.equal(count_mod_steps, 0)

    @classmethod
    def patch(cls, optimizer: tf.keras.optimizers.Optimizer, n: int) -> tf.keras.optimizers.Optimizer:
        """Patch optimizer for gradient accumulation.

        :param optimizer:
            The optimizer to patch.
        :param n:
            The number of accumulation steps before applying gradients.
        :return:
            A patched optimizer for gradient accumulation.
        """
        # policy = tf.keras.mixed_precision.global_policy()
        # if isinstance(optimizer, keras.mixed_precision.LossScaleOptimizer) or policy.name.startswith('mixed'):
        #     raise RuntimeError('Mixed-precision is not supported yet in combination with gradient accumulation.')

        accumulator = cls(cls.__create_key, n=n, optimizer=optimizer)
        optimizer.apply_gradients = accumulator.apply_gradients
        return optimizer


def get_ffn_model(input_size: int, output_size: int, hidden_size: int = 64) -> keras.Model:
    inputs = layers.Input(shape=(input_size,))
    x = inputs
    x = layers.Dense(units=hidden_size, activation='tanh')(x)
    x = layers.Dense(units=hidden_size, activation='tanh')(x)
    x = layers.Dense(units=output_size, activation='softmax')(x)
    return keras.Model(inputs=inputs, outputs=x)


def make_dataset(inputs, targets, batch_size: int, split: str, limit: int = None):
    def sample_generator_():
        while True:
            idx = np.random.randint(0, len(inputs))
            yield inputs[idx].flatten(), tf.one_hot(targets[idx], depth=num_classes)

    assert split in ('train', 'test', 'dev'), \
        f'Split must be one of "train", "test" or "dev". Got: {split}'

    inputs = inputs.astype(np.float32) / 255.0
    inputs = np.expand_dims(inputs, axis=-1)
    num_classes = len(set(targets))

    input_shape = (np.prod(inputs[0].shape),)
    target_shape = (num_classes,)

    dataset = tf.data.Dataset.from_generator(
        lambda: sample_generator_(),
        output_types=(tf.float32, tf.float32),
        output_shapes=(input_shape, target_shape)
    )

    is_training = split == 'train'

    if is_training:
        dataset = dataset.repeat()

    if limit:
        dataset = dataset.take(limit)

    return dataset.padded_batch(batch_size)


def main():
    train_batch_size = 1
    valid_batch_size = 10
    grad_acc_n = 4
    steps_per_epoch = 1000 * grad_acc_n  # Make sure we have the same number of updates

    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    train_data = make_dataset(x_train, y_train, batch_size=train_batch_size, split='train')
    valid_data = make_dataset(x_test, y_test, batch_size=valid_batch_size, split='dev', limit=500)

    input_size = train_data.element_spec[0].shape[-1]
    output_size = train_data.element_spec[1].shape[-1]

    epochs = 2

    for precision_policy in ['float32', 'mixed_float16']:
        print('#' * 72)
        print(f'Setting precision-policy to "{precision_policy}"')

        mixed_precision.set_global_policy(precision_policy)

        with tf.distribute.get_strategy().scope():
            model = get_ffn_model(input_size=input_size, output_size=output_size, hidden_size=8)
            optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
            optimizer = GradientAccumulationOptimizer.patch(optimizer, n=grad_acc_n)

            # This is not necessary because the optimizer will be wrapped by Keras see
            # https://github.com/tensorflow/tensorflow/blob/e2af7f7927655e1d0b048bed05afa5e5be8c1f9f/tensorflow/python/keras/engine/training.py#L593

            # if precision_policy.startswith('mixed'):
            #     print(f'Using LossScaleOptimizer for precision-policy "{precision_policy}"')
            #     optimizer = mixed_precision.LossScaleOptimizer(optimizer)

        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )

        model.fit(
            train_data,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch // train_batch_size,
            validation_data=valid_data,
            validation_steps=10
        )

        loss, accuracy = model.evaluate(valid_data)

        print('Evaluation')
        print(f'  - Loss:     {loss:.4f}')
        print(f'  - Accuracy: {accuracy:.4f}')


if __name__ == '__main__':
    main()

Full error log

Running a training with the patched optimizer will result in the following error:

ValueError: in user code:

    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:855 train_function  *
        return step_function(self, iterator)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:845 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:1285 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:2833 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/distribute/distribute_lib.py:3608 _call_for_each_replica
        return fn(*args, **kwargs)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:838 run_step  **
        outputs = model.train_step(data)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/engine/training.py:799 train_step
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:530 minimize
        return self.apply_gradients(grads_and_vars, name=name)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/keras/mixed_precision/loss_scale_optimizer.py:740 apply_gradients
        maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/framework/smart_cond.py:58 smart_cond
        return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:206 wrapper
        return target(*args, **kwargs)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/util/deprecation.py:535 new_func
        return func(*args, **kwargs)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/control_flow_ops.py:1254 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/cond_v2.py:99 cond_v2
        return _build_cond(
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/cond_v2.py:226 _build_cond
        _check_same_outputs(_COND, [true_graph, false_graph])
    /home/sfalk/miniconda3/envs/asr2/lib/python3.9/site-packages/tensorflow/python/ops/cond_v2.py:796 _check_same_outputs
        raise ValueError("Lengths of branch outputs of {op_type} must match.\n"

    ValueError: Lengths of branch outputs of cond must match.
    len(graphs[0].outputs): 0
    len(graphs[1].outputs): 1
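
If I read the trace correctly, the failing cond is the smart_cond built inside LossScaleOptimizer.apply_gradients (loss_scale_optimizer.py:740 above): its true branch ends up calling my patched apply_gradients(), which returns nothing (0 outputs), while the other branch returns an op (1 output). If that is indeed the cause, returning the op produced by the original apply_gradients() might make the branch outputs match. A hedged sketch of that change (untested under mixed_float16):

def apply_gradients(self, grads_and_vars, *args, **kwargs):
    # ... accumulation logic unchanged from above ...

    # Keep the op returned by the original apply_gradients() so the
    # LossScaleOptimizer cond sees one output in both branches (a guess)
    apply_op = self._orig_apply_gradients(zip(final_gradients, trainable_variables), *args, **kwargs)

    self.iterations.assign_add(-1 * tf.cast(keep, dtype=self.iterations.dtype))

    for g in accu_gradients:
        g.assign(g * keep)

    return apply_op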
