feat: support DynamicLossScale for TrainStep
geniuspatrick committed Jun 9, 2023
1 parent 7573837 commit 5a52bd2
Showing 7 changed files with 157 additions and 145 deletions.
2 changes: 0 additions & 2 deletions mindcv/utils/__init__.py
@@ -3,10 +3,8 @@
from .callbacks import *
from .checkpoint_manager import *
from .download import *
from .gradient_accumulation import *
from .path import *
from .random import *
from .reduce_manager import *
from .train_step import *
from .trainer_factory import *
from .utils import *
72 changes: 0 additions & 72 deletions mindcv/utils/gradient_accumulation.py

This file was deleted.

142 changes: 105 additions & 37 deletions mindcv/utils/train_step.py
@@ -1,44 +1,98 @@
"""Ema define"""
"""Customized TrainOneStepCell.
Supported algorithms are listed as follows:
* Exponential Moving Average (EMA)
* Gradient Clipping
* Gradient Accumulation
"""

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common import RowTensor
from mindspore.ops import composite as C
from mindspore import RowTensor, boost, nn, ops
from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op
from mindspore.ops import functional as F
from mindspore.ops import operations as P

from .gradient_accumulation import GradientAccumulation

__all__ = [
"GradientAccumulation",
"TrainStep",
]

_ema_op = C.MultitypeFuncGraph("grad_ema_op")
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
_ema_op = ops.MultitypeFuncGraph("ema_op")
_grad_scale = ops.MultitypeFuncGraph("grad_scale")
reciprocal = ops.Reciprocal()


@_ema_op.register("Tensor", "Tensor", "Tensor")
def _ema_weights(factor, ema_weight, weight):
def ema_ops(factor, ema_weight, weight):
return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
def grad_scale_tensor(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
def grad_scale_row_tensor(scale, grad):
return RowTensor(
grad.indices,
grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
grad.dense_shape,
)


class GradientAccumulation(boost.GradientAccumulation):
"""
After accumulating the gradients over multiple steps, call the optimizer to apply the update.
Note:
The implementation is based on mindspore.boost.GradientAccumulation, with the following modifications:
1. The learning rate is updated at every iteration step, whereas the original implementation
only updates it once every M iteration steps.
2. For distributed training, during the gradient accumulation stage each device maintains
its own accumulated gradients; at the parameter update stage, the gradients are first
synchronized across the devices and then applied.
Args:
max_accumulation_step (int): Number of steps over which gradients are accumulated.
optimizer (nn.Cell): Optimizer used.
grad_reducer (nn.Cell): Gradient reducer, which synchronizes gradients across the devices.
"""

def __init__(self, max_accumulation_step, optimizer, grad_reducer):
super().__init__(max_accumulation_step, optimizer)
self.grad_reducer = grad_reducer

def construct(self, loss, grads):
loss = F.depend(
loss,
self.hyper_map(
F.partial(gradient_accumulation_op, self._max_accumulation_step), self._grad_accumulation, grads
),
)
self._accumulation_step += 1

if self._accumulation_step >= self._max_accumulation_step:
# accumulate the gradients on each device; don't sync them until the weight update
reduced_grad_accumulation = self.grad_reducer(self._grad_accumulation)
loss = F.depend(loss, self.optimizer(reduced_grad_accumulation))
loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))
self._accumulation_step = 0
else:
# update the learning rate but do not update the parameters
loss = F.depend(loss, self.optimizer.get_lr())

return loss
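
# [Illustrative sketch, not part of this diff] The accumulate-locally /
# sync-on-update behavior described in the Note above, in plain Python.
# `all_reduce_mean` is a hypothetical stand-in for grad_reducer, and this
# assumes gradient_accumulation_op adds grad / max_accumulation_step so the
# final update applies the mean gradient.
def _sketch_accumulate_then_update(grads, buf, step, max_accumulation_step, optimizer_update, all_reduce_mean):
    contribution = [g / max_accumulation_step for g in grads]
    buf = contribution if buf is None else [a + c for a, c in zip(buf, contribution)]
    step += 1
    if step >= max_accumulation_step:
        reduced = all_reduce_mean(buf)  # sync across devices only now
        optimizer_update(reduced)       # one update with the mean gradient
        buf, step = None, 0
    # the real cell also advances the learning rate every step via optimizer.get_lr()
    return buf, step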


class TrainStep(nn.TrainOneStepWithLossScaleCell):
"""TrainStep with ema and clip grad."""
"""Training step with loss scale.
This customized TrainOneStepCell also supports the following algorithms:
* Exponential Moving Average (EMA)
* Gradient Clipping
* Gradient Accumulation
"""

def __init__(
self,
@@ -47,58 +101,72 @@ def __init__(
scale_sense=1.0,
ema=False,
ema_decay=0.9999,
updates=0,
clip_grad=False,
clip_value=15.0,
gradient_accumulation_steps=1,
):
super(TrainStep, self).__init__(network, optimizer, scale_sense)
self.ema = ema
self.ema_decay = ema_decay
self.updates = Parameter(Tensor(updates, ms.float32))
self.clip_grad = clip_grad
self.clip_value = clip_value
if self.ema:
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init="same")

self.need_accumulate_grad = gradient_accumulation_steps > 1
if self.need_accumulate_grad:
self.accumulate_grad = gradient_accumulation_steps > 1
if self.accumulate_grad:
self.gradient_accumulation = GradientAccumulation(gradient_accumulation_steps, optimizer, self.grad_reducer)

def ema_update(self):
"""Update EMA parameters."""
self.updates += 1
d = self.ema_decay * (1 - F.exp(-self.updates / 2000))
# the ema factor is corrected by (1 - exp(-t/T)), where `t` is the current step and `T` is a temperature-like time constant.
ema_decay = self.ema_decay * (1 - F.exp(-self.optimizer.global_step / 2000))
# update trainable parameters
success = self.hyper_map(F.partial(_ema_op, d), self.ema_weight, self.weights_all)
self.updates = F.depend(self.updates, success)
return self.updates
success = self.hyper_map(F.partial(_ema_op, ema_decay), self.ema_weight, self.weights_all)
return success
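
# [Aside, not part of this diff] Quick numeric check of the warmup correction,
# assuming ema_decay=0.9999 and the T=2000 constant used above: the factor is
# near 0 for early steps (so the EMA weights track the raw weights closely at
# the start of training) and approaches ema_decay later.
import math

def _sketch_corrected_decay(step, ema_decay=0.9999, temperature=2000.0):
    return ema_decay * (1 - math.exp(-step / temperature))

# _sketch_corrected_decay(100) ~ 0.0488; (2000) ~ 0.6321; (10000) ~ 0.9932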

def construct(self, *inputs):
"""construct"""
weights = self.weights
loss = self.network(*inputs)
scaling_sens = self.scale_sense

status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
scaling_sens_filled = ops.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)

# todo: When to clip grad? Do we need to clip grad after grad reduction? What if grad accumulation is needed?
if self.clip_grad:
grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)

if self.need_accumulate_grad:
# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
loss = self.gradient_accumulation(loss, grads, overflow)
else:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
if self.loss_scaling_manager: # scale_sense = update_cell: Cell --> TrainOneStepWithLossScaleCell.construct
if self.accumulate_grad:
# todo: GradientAccumulation only calls grad_reducer at the step where the accumulation is completed,
# so here the overflow status is checked before gradient reduction. Is this correct?
# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
# if there is no overflow, run the optimizer update
if not overflow:
loss = self.gradient_accumulation(loss, grads)
else:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
# if there is no overflow, run the optimizer update
if not overflow:
loss = F.depend(loss, self.optimizer(grads))
else: # scale_sense = loss_scale: Tensor --> TrainOneStepCell.construct
if self.accumulate_grad:
loss = self.gradient_accumulation(loss, grads)
else:
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))

if self.ema:
self.ema_update()
loss = F.depend(loss, self.ema_update())

return loss
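
Taken together, a minimal usage sketch of the new cell (illustrative only: `net_with_loss`, `opt`, `images`, and `labels` are placeholders; the two scale_sense choices mirror the two branches of construct above):

import mindspore as ms
from mindspore import Tensor, nn

# Case 1: fixed loss scale without overflow handling -> plain Tensor scale_sense
train_step = TrainStep(
    network=net_with_loss,                 # cell that returns the loss
    optimizer=opt,
    scale_sense=Tensor(1024, ms.float32),  # static scale, never updated
    ema=True,
    clip_grad=True,
    clip_value=15.0,
    gradient_accumulation_steps=4,
).set_train()

# Case 2: dynamic loss scale -> update cell; overflow steps skip the optimizer
update_cell = nn.DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=2000)
train_step = TrainStep(network=net_with_loss, optimizer=opt, scale_sense=update_cell).set_train()

loss = train_step(images, labels)  # one optimization step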
39 changes: 23 additions & 16 deletions mindcv/utils/trainer_factory.py
@@ -2,13 +2,14 @@
from typing import Union

import mindspore as ms
from mindspore import nn
from mindspore import Tensor, context, nn
from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model

from .train_step import TrainStep

__all__ = [
"get_metrics",
"require_customized_train_step",
"create_trainer",
]

@@ -28,9 +29,7 @@ def get_metrics(num_classes):
return metrics


def _require_customized_train_step(
ema: bool = False, clip_grad: bool = False, gradient_accumulation_steps: bool = False
):
def require_customized_train_step(ema: bool = False, clip_grad: bool = False, gradient_accumulation_steps: int = 1):
if ema:
return True
if clip_grad:
@@ -85,7 +84,7 @@ def create_trainer(
if gradient_accumulation_steps < 1:
raise ValueError("`gradient_accumulation_steps` must be >= 1!")

if not _require_customized_train_step(ema, clip_grad, gradient_accumulation_steps):
if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps):
mindspore_kwargs = dict(
network=network,
loss_fn=loss,
@@ -124,20 +123,28 @@
gradient_accumulation_steps=gradient_accumulation_steps,
)
if loss_scale_type.lower() == "fixed":
# todo: drop_overflow_update. If drop_overflow_update is False, scale_sense should be a number
# instead of cell, and TrainStep should be TrainOneStepCell. If drop_overflow_update is True,
# scale_sense should be FixedLossScaleUpdateCell, and TrainStep should be TrainOneStepWithLossScaleCell.
train_step_kwargs["scale_sense"] = nn.FixedLossScaleUpdateCell(loss_scale_value=loss_scale)
loss_scale_manager = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=drop_overflow_update)
elif loss_scale_type.lower() == "dynamic":
train_step_kwargs["scale_sense"] = nn.DynamicLossScaleUpdateCell(
loss_scale_value=loss_scale, scale_factor=2, scale_window=2000
)
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=loss_scale, scale_factor=2, scale_window=2000)
else:
raise ValueError(f"Loss scale type only support ['fixed', 'dynamic'], but got{loss_scale_type}.")
# todo: remove this check when TrainStep support dynamic loss scale and dropping overflow
if drop_overflow_update or loss_scale_type.lower() != "fixed":
raise ValueError("TrainStep only support fixed loss scale without dropping overflow!")
train_step_cell = TrainStep(**train_step_kwargs)
update_cell = loss_scale_manager.get_update_cell()
# 1. loss_scale_type="fixed", drop_overflow_update=False
# --> update_cell=None, TrainStep=TrainOneStepCell(scale_sense=loss_scale)
# 2. loss_scale_type: fixed, drop_overflow_update: True
# --> update_cell=FixedLossScaleUpdateCell, TrainStep=TrainOneStepWithLossScaleCell(scale_sense=update_cell)
# 3. loss_scale_type: dynamic, drop_overflow_update: True
# --> update_cell=DynamicLossScaleUpdateCell, TrainStep=TrainOneStepWithLossScaleCell(scale_sense=update_cell)
if update_cell is None:
train_step_kwargs["scale_sense"] = Tensor(loss_scale, dtype=ms.float32)
else:
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
raise ValueError(
"Only `loss_scale_type` of `fixed` with `drop_overflow_update` of `False` "
"is supported on device `CPU`."
)
train_step_kwargs["scale_sense"] = update_cell
train_step_cell = TrainStep(**train_step_kwargs).set_train()
eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2])
# todo: do we need to set model._loss_scale_manager
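
The manager-to-update-cell mapping in the comments above can be verified directly; a sketch assuming the standard MindSpore managers:

from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager

# fixed + drop_overflow_update=False -> no update cell; scale_sense stays a plain Tensor
assert FixedLossScaleManager(loss_scale=1024, drop_overflow_update=False).get_update_cell() is None

# fixed + drop_overflow_update=True -> FixedLossScaleUpdateCell
print(type(FixedLossScaleManager(loss_scale=1024, drop_overflow_update=True).get_update_cell()))

# dynamic -> DynamicLossScaleUpdateCell (overflow shrinks the scale; a stable scale_window grows it)
print(type(DynamicLossScaleManager(init_loss_scale=1024, scale_factor=2, scale_window=2000).get_update_cell()))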
15 changes: 0 additions & 15 deletions mindcv/utils/utils.py

This file was deleted.

Diffs for the remaining two changed files are not shown.
