
Add dynamic loss scaler support and update #97


Merged · 3 commits · Mar 23, 2023
configs/det/db_r50_icdar15.yaml (6 changes: 5 additions & 1 deletion)
@@ -54,7 +54,11 @@ optimizer:
   filter_bias_and_bn: false
   momentum: 0.9
   weight_decay: 1.0e-4
-  loss_scale: 1.0
+
+# only used for mixed precision training
+loss_scaler:
+  type: static
+  loss_scale: 1.0
 
 train:
   ckpt_save_dir: './tmp_det'
configs/rec/crnn/crnn_icdar15.yaml (9 changes: 8 additions & 1 deletion)
@@ -57,9 +57,16 @@ optimizer:
   filter_bias_and_bn: True
   momentum: 0.9
   weight_decay: 0.0001
-  loss_scale: 1.0
+  #loss_scale: 1.0
   #use_nesterov: True
 
+# only used for mixed precision training
+loss_scaler:
+  type: dynamic
+  loss_scale: 1.0
+  scale_factor: 2.0
+  scale_window: 2000
+
 train:
   ckpt_save_dir: './tmp_rec'
   dataset_sink_mode: False
configs/rec/crnn/crnn_resnet34.yaml (5 changes: 4 additions & 1 deletion)
@@ -60,10 +60,13 @@ optimizer:
   filter_bias_and_bn: True
   momentum: 0.95
   weight_decay: 0.0001
-  loss_scale: 512
   nesterov: False
   #use_nesterov: True
 
+loss_scaler:
+  type: static
+  loss_scale: 512
+
 train:
   ckpt_save_dir: './tmp_rec'
   dataset_sink_mode: False
configs/rec/crnn/crnn_vgg7.yaml (5 changes: 4 additions & 1 deletion)
@@ -60,10 +60,13 @@ optimizer:
   filter_bias_and_bn: True
   momentum: 0.95
   weight_decay: 0.0001
-  loss_scale: 1024
   nesterov: False
   #use_nesterov: True
 
+loss_scaler:
+  type: static
+  loss_scale: 1024
+
 train:
   ckpt_save_dir: './tmp_rec'
   dataset_sink_mode: False
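A note on the new block: `type: static` keeps a fixed loss scale (the previous behaviour, with the value simply moved out of the optimizer section), while `type: dynamic` lets the scale adapt during training. The following is a rough, host-side sketch of the update rule controlled by `scale_factor` and `scale_window`; the real logic lives in MindSpore's nn.DynamicLossScaleUpdateCell, so this simplified class is illustrative only:

class DynamicScaleSketch:
    """Simplified model of dynamic loss scaling (illustration, not the MindSpore implementation)."""

    def __init__(self, loss_scale=2.0 ** 16, scale_factor=2.0, scale_window=2000):
        self.scale = loss_scale
        self.factor = scale_factor
        self.window = scale_window
        self.good_steps = 0

    def update(self, overflow: bool) -> float:
        if overflow:
            # overflow detected: shrink the scale and restart the window
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            # count overflow-free steps; after a full window, grow the scale
            self.good_steps += 1
            if self.good_steps % self.window == 0:
                self.scale *= self.factor
        return self.scale

With the crnn_icdar15.yaml values above (scale_factor 2.0, scale_window 2000), the scale doubles after every 2000 consecutive overflow-free steps and is halved whenever an overflow occurs.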
mindocr/utils/loss_scaler.py (35 changes: 35 additions & 0 deletions)
@@ -0,0 +1,35 @@
from mindspore import nn


def get_loss_scales(cfg):
    '''
    Args:
        cfg (dict): config dict for the loss scaler

    Returns:
        nn.Cell: scale_sens used to scale the gradients
        float: loss_scale used in the optimizer (only used when the loss scaler type is static and drop_overflow_update is False)
    '''
    # loss scale is 1.0 by default
    loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=1.0)
    optimizer_loss_scale = 1.0

    if 'loss_scaler' in cfg:
        assert 'loss_scale' in cfg.loss_scaler, 'Must specify the value for `loss_scale` in the config if `loss_scaler` is used.'
        if cfg.loss_scaler.type == 'dynamic':
            # TODO: scale_window can be related to num_batches, e.g., scale_window = num_batches * 2
            loss_scale_manager = nn.DynamicLossScaleUpdateCell(
                loss_scale_value=cfg.loss_scaler.get('loss_scale', 2**16),
                scale_factor=cfg.loss_scaler.get('scale_factor', 2.0),
                scale_window=cfg.loss_scaler.get('scale_window', 2000),
            )
        elif cfg.loss_scaler.type == 'static':
            loss_scale = cfg.loss_scaler.get('loss_scale', 1.0)
            loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale)
            # when using a static loss scaler and drop_overflow_update is False, also set loss_scale for the optimizer
            if not cfg.system.drop_overflow_update:
                optimizer_loss_scale = loss_scale
        else:
            raise ValueError(f'Available loss scaler types are `static` and `dynamic`, but got {cfg.loss_scaler}')

    return loss_scale_manager, optimizer_loss_scale
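For reference, a minimal usage sketch of get_loss_scales (not part of the diff), assuming an addict-style config as built in tests/ut/test_loss_scaler.py further below; the values here are placeholders:

# Minimal usage sketch; the config keys match this PR, the values are illustrative.
from addict import Dict
from mindocr.utils.loss_scaler import get_loss_scales

cfg = Dict()
cfg.system.drop_overflow_update = False
cfg.loss_scaler = {'type': 'dynamic', 'loss_scale': 1024.0,
                   'scale_factor': 2.0, 'scale_window': 1000}

# returns an nn.DynamicLossScaleUpdateCell to pass as `scale_sense`, plus the
# loss_scale for the optimizer (kept at 1.0 here because the scaler is dynamic)
loss_scale_manager, optimizer_loss_scale = get_loss_scales(cfg)

With type: static and system.drop_overflow_update: False, the same call would instead return an nn.FixedLossScaleUpdateCell and forward the configured loss_scale to the optimizer.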

mindocr/utils/train_step_wrapper.py (10 changes: 7 additions & 3 deletions)
@@ -37,6 +37,10 @@ class TrainOneStepWrapper(nn.TrainOneStepWithLossScaleCell):
     """TrainStep with ema and clip grad.
     Args:
         drop_overflow_update: if True, the network will not be updated when the gradients overflow.
+        scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called
+            to update the loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`;
+            the shape should be :math:`()` or :math:`(1,)`.
+
     """

     def __init__(
@@ -85,7 +89,7 @@ def construct(self, *inputs):
             status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
         else:
             status = None

         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
@@ -97,8 +101,8 @@ def construct(self, *inputs):
             cond = self.get_overflow_status(status, grads)
             overflow = self.process_loss_scale(cond)
         else:
-            overflow = False
-            cond = False
+            overflow = ms.Tensor(False)
+            cond = ms.Tensor(False)

         if self.drop_overflow_update:
             # if there is no overflow, do optimize
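For context, `scale_sense` follows the same convention as the base class nn.TrainOneStepWithLossScaleCell, which the unit test further below uses directly. A minimal sketch of the two accepted forms, using a toy network that is not part of this PR:

# Toy illustration of the two forms of `scale_sense` (Cell vs. scalar Tensor);
# TrainOneStepWrapper layers drop_overflow_update, EMA and gradient clipping
# on top of this base behaviour.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

net = nn.Dense(16, 10)
net_with_loss = nn.WithLossCell(net, nn.MSELoss())
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# 1) scale_sense as a Cell: the loss scale is adjusted after every step
manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**16, scale_factor=2.0, scale_window=2000)
train_net = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)

# 2) scale_sense as a scalar Tensor: a fixed loss scale, changeable later via set_sense_scale
train_net_fixed = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer,
                                                   scale_sense=Tensor(1024.0, ms.float32))

x = Tensor(np.ones([8, 16]), ms.float32)
y = Tensor(np.ones([8, 10]), ms.float32)
loss, overflow, scale = train_net(x, y)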
tests/st/test_train_eval_dummy.py (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@


 @pytest.mark.parametrize("task", ["det", "rec"])
-@pytest.mark.parametrize("val_while_train", [True, False])
+@pytest.mark.parametrize("val_while_train", [False, True])
 def test_train_eval(task, val_while_train):
     # prepare dummy images
     data_dir = "data/Canidae"
tests/ut/test_loss_scaler.py (73 changes: 73 additions & 0 deletions)
@@ -0,0 +1,73 @@
import sys
sys.path.append('.')

import numpy as np
import mindspore
from mindspore import Tensor, Parameter, nn
import mindspore.ops as ops
import pytest
from mindocr.utils.loss_scaler import get_loss_scales
from addict import Dict


class Net(nn.Cell):
    def __init__(self, in_features, out_features):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
                                name='weight')
        self.matmul = ops.MatMul()

    def construct(self, x):
        output = self.matmul(x, self.weight)
        return output


static_cfg = {}
static_cfg['loss_scaler'] = {'type': 'static',
                             'loss_scale': 1.0}
static_cfg = Dict(static_cfg)

dynamic_cfg = {}
dynamic_cfg['loss_scaler'] = {'type': 'dynamic',
                              'loss_scale': 1024.0,
                              'scale_factor': 2.0,
                              'scale_window': 2}
dynamic_cfg = Dict(dynamic_cfg)


@pytest.mark.parametrize('ls_type', ['static', 'dynamic'])
@pytest.mark.parametrize('drop_overflow_update', [True, False])
def test_loss_scaler(ls_type, drop_overflow_update):
    in_features, out_features = 16, 10
    net = Net(in_features, out_features)
    loss = nn.MSELoss()
    net_with_loss = nn.WithLossCell(net, loss)

    #manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
    if ls_type == 'static':
        cfg = static_cfg
    elif ls_type == 'dynamic':
        cfg = dynamic_cfg
    cfg.system.drop_overflow_update = drop_overflow_update

    manager, opt_loss_scale = get_loss_scales(cfg)

    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=opt_loss_scale)
    train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)

    input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
    labels = Tensor(np.ones([out_features, ]), mindspore.float32)

    loss_scales = []
    for i in range(3):
        loss, is_overflow, loss_scale_updated = train_network(input, labels)
        loss_scales.append(float(loss_scale_updated.asnumpy()))
        print(loss)

    print(loss_scales)

    if ls_type == 'static':
        assert loss_scales[0] == loss_scales[-1]
    elif ls_type == 'dynamic':
        assert loss_scales[0] != loss_scales[-1]


if __name__ == '__main__':
    test_loss_scaler('static', True)
tools/train.py (29 changes: 18 additions & 11 deletions)
@@ -34,10 +34,11 @@
 from mindocr.utils.train_step_wrapper import TrainOneStepWrapper
 from mindocr.utils.callbacks import EvalSaveCallback
 from mindocr.utils.seed import set_seed
+from mindocr.utils.loss_scaler import get_loss_scales


 def main(cfg):
-    # env init
+    # init env
     ms.set_context(mode=cfg.system.mode)
     if cfg.system.distribute:
         init()
@@ -56,8 +57,7 @@ def main(cfg):
     #cv2.setNumThreads(2)  # TODO: by default, num threads = num cpu cores
     is_main_device = rank_id in [None, 0]

-    # train pipeline
-    # dataset
+    # create dataset
     loader_train = build_dataset(
         cfg.train.dataset,
         cfg.train.loader,
@@ -76,25 +76,30 @@ def main(cfg):
         shard_id=None,
         is_train=False)

-    # model
+    # create model
     network = build_model(cfg.model)
     ms.amp.auto_mixed_precision(network, amp_level=cfg.system.amp_level)

-    # optimizer and sheduler (from mindcv)
-    lr_scheduler = create_scheduler(num_batches, **cfg['scheduler'])
-    optimizer = create_optimizer(network.trainable_params(), lr=lr_scheduler, **cfg['optimizer'])
     # create loss
     loss_fn = build_loss(cfg.loss.pop('name'), **cfg['loss'])

-    # wrap train-one-step cell
-    net_with_loss = NetWithLossWrapper(network, loss_fn)
+    net_with_loss = NetWithLossWrapper(network, loss_fn)  # wrap train-one-step cell

+    # get loss scale setting for mixed precision training
+    loss_scale_manager, optimizer_loss_scale = get_loss_scales(cfg)

+    # build lr scheduler
+    lr_scheduler = create_scheduler(num_batches, **cfg['scheduler'])

-    loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=cfg.optimizer.loss_scale)
+    # build optimizer
+    cfg.optimizer.update({'lr': lr_scheduler, 'loss_scale': optimizer_loss_scale})
+    optimizer = create_optimizer(network.trainable_params(), **cfg.optimizer)

+    # build train step cell
     train_net = TrainOneStepWrapper(net_with_loss,
                                     optimizer=optimizer,
                                     scale_sense=loss_scale_manager,
                                     drop_overflow_update=cfg.system.drop_overflow_update,
                                     verbose=True
                                     )
     # postprocess, metric
     postprocessor = None
@@ -124,6 +129,8 @@ def main(cfg):
             f'Optimizer: {cfg.optimizer.opt}\n'
             f'Scheduler: {cfg.scheduler.scheduler}\n'
             f'LR: {cfg.scheduler.lr} \n'
+            f'Auto mixed precision: {cfg.system.amp_level}\n'
+            f'Loss scale setting: {cfg.loss_scaler}\n'
             f'drop_overflow_update: {cfg.system.drop_overflow_update}'
         )
     if 'name' in cfg.model: