
feat: support DynamicLossScale for TrainStep #678

Merged 1 commit on Jun 9, 2023
2 changes: 0 additions & 2 deletions mindcv/utils/__init__.py
@@ -3,10 +3,8 @@
from .callbacks import *
from .checkpoint_manager import *
from .download import *
from .gradient_accumulation import *
from .path import *
from .random import *
from .reduce_manager import *
from .train_step import *
from .trainer_factory import *
from .utils import *
72 changes: 0 additions & 72 deletions mindcv/utils/gradient_accumulation.py

This file was deleted.

142 changes: 105 additions & 37 deletions mindcv/utils/train_step.py
@@ -1,44 +1,98 @@
"""Ema define"""
"""Customized TrainOneStepCell.

Supported algorithms are listed as follows:
* Exponential Moving Average (EMA)
* Gradient Clipping
* Gradient Accumulation
"""

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common import RowTensor
from mindspore.ops import composite as C
from mindspore import RowTensor, boost, nn, ops
from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op
from mindspore.ops import functional as F
from mindspore.ops import operations as P

from .gradient_accumulation import GradientAccumulation

__all__ = [
"GradientAccumulation",
"TrainStep",
]

_ema_op = C.MultitypeFuncGraph("grad_ema_op")
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
_ema_op = ops.MultitypeFuncGraph("ema_op")
_grad_scale = ops.MultitypeFuncGraph("grad_scale")
reciprocal = ops.Reciprocal()


@_ema_op.register("Tensor", "Tensor", "Tensor")
def _ema_weights(factor, ema_weight, weight):
def ema_ops(factor, ema_weight, weight):
return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
def grad_scale_tensor(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
def grad_scale_row_tensor(scale, grad):
return RowTensor(
grad.indices,
grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
grad.dense_shape,
)
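
The two `_grad_scale` handlers above undo the loss scaling that was applied before backpropagation. A minimal NumPy sketch (plain Python rather than MindSpore graph ops, illustrative numbers only) of why multiplying by the reciprocal of the scale recovers the true gradient:

```python
import numpy as np

scale = np.float32(1024.0)                  # loss scale applied before backprop
true_grad = np.float32(3e-5)                # a gradient too small for comfortable fp16 math
scaled_grad = true_grad * scale             # what backprop of (loss * scale) produces
recovered = scaled_grad * np.float32(1.0) / scale   # same idea as grads * reciprocal(scale)
assert np.isclose(recovered, true_grad)
```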


class GradientAccumulation(boost.GradientAccumulation):
"""
Accumulate gradients over multiple steps, then call the optimizer to apply the update.

Note:
The implementation is based on mindspore.boost.GradientAccumulation with the following modifications:

1. The learning rate is updated at every iteration step, whereas the original implementation
only updates it once every M iteration steps.
2. For distributed training, each device maintains its own accumulated gradient during the
accumulation stage; at the parameter-update stage, the gradients are first synchronized
across devices and then used to update the parameters.

Args:
max_accumulation_step (int): Steps to accumulate gradients.
optimizer (nn.Cell): Optimizer used.
grad_reducer (nn.Cell): Gradient reducer, which synchronizes gradients across the devices.
"""

def __init__(self, max_accumulation_step, optimizer, grad_reducer):
super().__init__(max_accumulation_step, optimizer)
self.grad_reducer = grad_reducer

def construct(self, loss, grads):
loss = F.depend(
loss,
self.hyper_map(
F.partial(gradient_accumulation_op, self._max_accumulation_step), self._grad_accumulation, grads
),
)
self._accumulation_step += 1

if self._accumulation_step >= self._max_accumulation_step:
# accumulate the gradient at each device, don't sync them until updating the weight
reduced_grad_accumulation = self.grad_reducer(self._grad_accumulation)
loss = F.depend(loss, self.optimizer(reduced_grad_accumulation))
loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))
self._accumulation_step = 0
else:
# update the learning rate, do not update the parameter
loss = F.depend(loss, self.optimizer.get_lr())

return loss
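
For intuition, the accumulation schedule above can be sketched in plain Python. This is a hedged sketch, not the MindSpore graph code; `sync_across_devices` is a hypothetical stand-in for `grad_reducer`, and the per-step averaging mirrors what `gradient_accumulation_op` is assumed to do here:

```python
accumulation_steps = 4
acc, counter = 0.0, 0

def sync_across_devices(g):
    # stand-in for grad_reducer; the identity on a single device
    return g

for micro_grad in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
    acc += micro_grad / accumulation_steps   # accumulate a running (averaged) gradient
    counter += 1
    if counter >= accumulation_steps:
        update = sync_across_devices(acc)    # sync only when the weights are updated
        print("apply optimizer step with", update)
        acc, counter = 0.0, 0                # clear the buffer, like gradient_clear_op
    # otherwise only the learning-rate/global-step counter advances; no parameter update
```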


class TrainStep(nn.TrainOneStepWithLossScaleCell):
"""TrainStep with ema and clip grad."""
"""Training step with loss scale.

This customized TrainOneStepCell also supports the following algorithms:
* Exponential Moving Average (EMA)
* Gradient Clipping
* Gradient Accumulation
"""

def __init__(
self,
@@ -47,58 +101,72 @@ def __init__(
scale_sense=1.0,
ema=False,
ema_decay=0.9999,
updates=0,
clip_grad=False,
clip_value=15.0,
gradient_accumulation_steps=1,
):
super(TrainStep, self).__init__(network, optimizer, scale_sense)
self.ema = ema
self.ema_decay = ema_decay
self.updates = Parameter(Tensor(updates, ms.float32))
self.clip_grad = clip_grad
self.clip_value = clip_value
if self.ema:
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init="same")

self.need_accumulate_grad = gradient_accumulation_steps > 1
if self.need_accumulate_grad:
self.accumulate_grad = gradient_accumulation_steps > 1
if self.accumulate_grad:
self.gradient_accumulation = GradientAccumulation(gradient_accumulation_steps, optimizer, self.grad_reducer)

def ema_update(self):
"""Update EMA parameters."""
self.updates += 1
d = self.ema_decay * (1 - F.exp(-self.updates / 2000))
# the ema decay is warmed up by (1 - exp(-t/T)), where `t` is the global step and `T` is a time constant (2000 here).
ema_decay = self.ema_decay * (1 - F.exp(-self.optimizer.global_step / 2000))
# update trainable parameters
success = self.hyper_map(F.partial(_ema_op, d), self.ema_weight, self.weights_all)
self.updates = F.depend(self.updates, success)
return self.updates
success = self.hyper_map(F.partial(_ema_op, ema_decay), self.ema_weight, self.weights_all)
return success
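
The warm-up term in `ema_update` keeps the effective decay small early in training so the EMA can catch up with the fast-changing weights. A plain-Python sketch of the schedule (illustrative values, assuming the 2000-step time constant used above):

```python
import math

def effective_decay(step, base_decay=0.9999, time_constant=2000):
    # mirrors base_decay * (1 - exp(-step / 2000)) from ema_update
    return base_decay * (1 - math.exp(-step / time_constant))

print(round(effective_decay(100), 4))    # ~0.0488: EMA mostly follows the raw weights
print(round(effective_decay(10000), 4))  # ~0.9932: EMA changes slowly, as intended
```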

def construct(self, *inputs):
"""construct"""
weights = self.weights
loss = self.network(*inputs)
scaling_sens = self.scale_sense

status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
scaling_sens_filled = ops.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)

# todo: When to clip grad? Do we need to clip grad after grad reduction? What if grad accumulation is needed?
if self.clip_grad:
grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)

if self.need_accumulate_grad:
# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
loss = self.gradient_accumulation(loss, grads, overflow)
else:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
if self.loss_scaling_manager: # scale_sense = update_cell: Cell --> TrainOneStepWithLossScaleCell.construct
if self.accumulate_grad:
# todo: GradientAccumulation only calls grad_reducer at the step where the accumulation is completed.
# So here the overflow status is checked before the gradients are reduced. Is this correct?
# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
# if there is no overflow, do optimize
if not overflow:
loss = self.gradient_accumulation(loss, grads)
else:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
# if there is no overflow, do optimize
if not overflow:
loss = F.depend(loss, self.optimizer(grads))
else: # scale_sense = loss_scale: Tensor --> TrainOneStepCell.construct
if self.accumulate_grad:
loss = self.gradient_accumulation(loss, grads)
else:
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))

if self.ema:
self.ema_update()
loss = F.depend(loss, self.ema_update())

return loss
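
When `scale_sense` is a DynamicLossScaleUpdateCell, `process_loss_scale` adjusts the loss scale from the overflow flag. The policy is roughly the following conceptual sketch, using the scale_factor=2 / scale_window=2000 values configured in create_trainer; the exact MindSpore internals may differ:

```python
SCALE_FACTOR, SCALE_WINDOW = 2, 2000

def update_loss_scale(scale, good_steps, overflow):
    """Return the next (scale, good_steps); overflow steps also skip the optimizer update."""
    if overflow:
        return max(scale / SCALE_FACTOR, 1.0), 0   # shrink the scale, reset the window
    good_steps += 1
    if good_steps >= SCALE_WINDOW:
        return scale * SCALE_FACTOR, 0             # grow again after a clean window of steps
    return scale, good_steps

scale, good = 65536.0, 0
for has_overflow in [False, True, False]:
    scale, good = update_loss_scale(scale, good, has_overflow)
```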
39 changes: 23 additions & 16 deletions mindcv/utils/trainer_factory.py
@@ -2,13 +2,14 @@
from typing import Union

import mindspore as ms
from mindspore import nn
from mindspore import Tensor, context, nn
from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model

from .train_step import TrainStep

__all__ = [
"get_metrics",
"require_customized_train_step",
"create_trainer",
]

Expand All @@ -28,9 +29,7 @@ def get_metrics(num_classes):
return metrics


def _require_customized_train_step(
ema: bool = False, clip_grad: bool = False, gradient_accumulation_steps: bool = False
):
def require_customized_train_step(ema: bool = False, clip_grad: bool = False, gradient_accumulation_steps: int = 1):
if ema:
return True
if clip_grad:
@@ -85,7 +84,7 @@ def create_trainer(
if gradient_accumulation_steps < 1:
raise ValueError("`gradient_accumulation_steps` must be >= 1!")

if not _require_customized_train_step(ema, clip_grad, gradient_accumulation_steps):
if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps):
mindspore_kwargs = dict(
network=network,
loss_fn=loss,
@@ -124,20 +123,28 @@
gradient_accumulation_steps=gradient_accumulation_steps,
)
if loss_scale_type.lower() == "fixed":
# todo: drop_overflow_update. If drop_overflow_update is False, scale_sense should be a number
# instead of cell, and TrainStep should be TrainOneStepCell. If drop_overflow_update is True,
# scale_sense should be FixedLossScaleUpdateCell, and TrainStep should be TrainOneStepWithLossScaleCell.
train_step_kwargs["scale_sense"] = nn.FixedLossScaleUpdateCell(loss_scale_value=loss_scale)
loss_scale_manager = FixedLossScaleManager(loss_scale=loss_scale, drop_overflow_update=drop_overflow_update)
elif loss_scale_type.lower() == "dynamic":
train_step_kwargs["scale_sense"] = nn.DynamicLossScaleUpdateCell(
loss_scale_value=loss_scale, scale_factor=2, scale_window=2000
)
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=loss_scale, scale_factor=2, scale_window=2000)
else:
raise ValueError(f"Loss scale type only support ['fixed', 'dynamic'], but got{loss_scale_type}.")
# todo: remove this check when TrainStep support dynamic loss scale and dropping overflow
if drop_overflow_update or loss_scale_type.lower() != "fixed":
raise ValueError("TrainStep only support fixed loss scale without dropping overflow!")
train_step_cell = TrainStep(**train_step_kwargs)
update_cell = loss_scale_manager.get_update_cell()
# 1. loss_scale_type="fixed", drop_overflow_update=False
# --> update_cell=None, TrainStep=TrainOneStepCell(scale_sense=loss_scale)
# 2. loss_scale_type="fixed", drop_overflow_update=True
# --> update_cell=FixedLossScaleUpdateCell, TrainStep=TrainOneStepWithLossScaleCell(scale_sense=update_cell)
# 3. loss_scale_type="dynamic", drop_overflow_update=True
# --> update_cell=DynamicLossScaleUpdateCell, TrainStep=TrainOneStepWithLossScaleCell(scale_sense=update_cell)
if update_cell is None:
train_step_kwargs["scale_sense"] = Tensor(loss_scale, dtype=ms.float32)
else:
if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
raise ValueError(
"Only `loss_scale_type` is `fixed` and `drop_overflow_update` is `False`"
"are supported on device `CPU`."
)
train_step_kwargs["scale_sense"] = update_cell
train_step_cell = TrainStep(**train_step_kwargs).set_train()
eval_network = nn.WithEvalCell(network, loss, amp_level in ["O2", "O3", "auto"])
model = Model(train_step_cell, eval_network=eval_network, metrics=metrics, eval_indexes=[0, 1, 2])
# todo: do we need to set model._loss_scale_manager
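
To make the three configurations concrete, here is a small sketch of how the managers map to `scale_sense` (assuming MindSpore's default `get_update_cell` behavior; the numeric values are illustrative only):

```python
import mindspore as ms
from mindspore import Tensor
from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager

fixed = FixedLossScaleManager(loss_scale=1024, drop_overflow_update=False)
dynamic = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)

# fixed scale without dropping overflow: no update cell, scale_sense is a plain Tensor
assert fixed.get_update_cell() is None
scale_sense_fixed = Tensor(1024, dtype=ms.float32)

# dynamic scale: scale_sense is the DynamicLossScaleUpdateCell returned by the manager
scale_sense_dynamic = dynamic.get_update_cell()
```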
15 changes: 0 additions & 15 deletions mindcv/utils/utils.py

This file was deleted.
