
Add gradient accumulation, gradient clip, and add tests #182

Merged
merged 6 commits into from
Apr 15, 2023
7 changes: 5 additions & 2 deletions configs/rec/crnn/crnn_icdar15.yaml
@@ -63,11 +63,14 @@ optimizer:
# only used for mixed precision training
loss_scaler:
type: dynamic
loss_scale: 1.0
loss_scale: 512
scale_factor: 2.0
scale_window: 2000
scale_window: 1000

train:
gradient_accumulation_steps: 2
clip_grad: True
clip_norm: 0.0001
ckpt_save_dir: './tmp_rec'
dataset_sink_mode: False
dataset:
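For reference, a minimal sketch of how these new `train` keys could be wired into the training step wrapper introduced in this PR. The variables cfg, network, optimizer and loss_scale_manager are assumed to come from the surrounding training script, and the import path is inferred from the file layout below; keyword names mirror the TrainOneStepWrapper arguments added in this PR.

from mindocr.utils.train_step_wrapper import TrainOneStepWrapper

# hypothetical glue code, not taken verbatim from this PR
train_cfg = cfg.train
wrapper = TrainOneStepWrapper(
    network,
    optimizer,
    scale_sense=loss_scale_manager,
    drop_overflow_update=cfg.system.drop_overflow_update,
    gradient_accumulation_steps=train_cfg.get('gradient_accumulation_steps', 1),
    clip_grad=train_cfg.get('clip_grad', False),
    clip_norm=train_cfg.get('clip_norm', 1.0),
)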
16 changes: 8 additions & 8 deletions mindocr/data/builder.py
@@ -25,7 +25,7 @@ def build_dataset(

Args:
dataset_config (dict): dataset parsing and processing configuration containing the following keys
- type (str): dataset class name, please choose from `supported_dataset_types`.
- type (str): dataset class name, please choose from `supported_dataset_types`.
- dataset_root (str): the root directory to store the (multiple) dataset(s)
- data_dir (Union[str, List[str]]): directory to the data, which is a subfolder path related to `dataset_root`. For multiple datasets, it is a list of subfolder paths.
- label_file (Union[str, List[str]], *optional*): file path to the annotation related to the `dataset_root`. For multiple datasets, it is a list of relative file paths. Not required if using LMDBDataset.
@@ -41,8 +41,8 @@ def build_dataset(
num_shards (int, *optional*): num of devices for distributed mode
shard_id (int, *optional*): device id
is_train (boolean): whether it is in training stage
**kwargs: optional args for extension. If `refine_batch_size=True` is given in kwargs, the batch size will be refined to be divisible to avoid
dropping remaining data samples in graph mode, typically used for precise evaluation.
**kwargs: optional args for extension. If `refine_batch_size=True` is given in kwargs, the batch size will be refined to be divisible to avoid
dropping remaining data samples in graph mode, typically used for precise evaluation.

Return:
data_loader (Dataset): dataloader to generate data batch
@@ -52,8 +52,8 @@ def build_dataset(
- Each of the three steps supports multiprocess. Detailed mechanism can be seen in https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/api_python/mindspore.dataset.html
- A data row is a data tuple item containing multiple elements such as (image_i, mask_i, label_i). A data column corresponds to an element in the tuple like 'image', 'label'.
- The total number of `num_workers` used for data loading and processing should not be larger than the maximum threads of the CPU. Otherwise, it will lead to resource contention overhead. Especially for distributed training, `num_parallel_workers` should not be too large to avoid thread competition.
Example:

Example:
>>> # Load a DetDataset/RecDataset
>>> from mindocr.data import build_dataset
>>> data_config = {
@@ -114,7 +114,7 @@ def build_dataset(
# TODO: find optimal setting automatically according to num of CPU cores
num_workers = loader_config.get("num_workers", 8) # Number of subprocesses used to fetch the dataset/map data row/gen batch in parallel
cores = multiprocessing.cpu_count()
num_devices = 1 if num_shards is None else num_shards
num_devices = 1 if num_shards is None else num_shards
if num_workers > int(cores / num_devices):
num_workers = int(cores / num_devices)
print(f'WARNING: num_workers is adjusted to {num_workers}, to fit {cores} CPU cores shared for {num_devices} devices')
@@ -144,9 +144,9 @@ def build_dataset(
# get batch of dataset by collecting batch_size consecutive data rows and apply batch operations
num_samples = ds.get_dataset_size()
batch_size = loader_config['batch_size']
print('INFO: num_samples: {num_samples}, batch_size: {batch_size}')
print(f'INFO: num_samples: {num_samples}, batch_size: {batch_size}')
if 'refine_batch_size' in kwargs:
batch_size = _check_batch_size(num_samples, batch_size, refine=kwargs['refine_batch_size'])
batch_size = _check_batch_size(num_samples, batch_size, refine=kwargs['refine_batch_size'])

drop_remainder = loader_config.get('drop_remainder', is_train)
if is_train and drop_remainder == False:
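A side note on `refine_batch_size`: the docstring above says the batch size is refined to be divisible so that no remaining samples are dropped in graph mode. A standalone sketch of that idea follows; it is illustrative only and the actual `_check_batch_size` helper in mindocr may differ.

def _check_batch_size(num_samples, ori_batch_size=32, refine=True):
    # illustrative refinement: pick the largest batch size <= ori_batch_size that divides num_samples
    if not refine or num_samples % ori_batch_size == 0:
        return ori_batch_size
    for bs in range(ori_batch_size - 1, 0, -1):
        if num_samples % bs == 0:
            print(f'INFO: batch size is refined from {ori_batch_size} to {bs} to divide num_samples {num_samples}')
            return bs
    return 1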
3 changes: 3 additions & 0 deletions mindocr/optim/optim_factory.py
@@ -92,6 +92,8 @@ def create_optimizer(
# if lr is not None:
# opt_args.setdefault('lr', lr)

assert loss_scale == 1.0, 'loss scale must be 1.0 in the optimizer because gradients are already scaled in TrainOneStepWrapper.'

# non-adaptive: SGD, momentum, and nesterov
if opt == "sgd":
# note: nn.Momentum may perform better if momentum > 0.
@@ -207,6 +209,7 @@ def _collect_args(kwargs, optim_class):
ret = {}
valid_args = list(inspect.signature(optim_class.__init__).parameters.keys())[1:] # remove self
for arg in valid_args:
assert arg != 'clip', 'Gradient clipping should not be set in `optimizer`. Please set it in `train`.'
if arg in kwargs:
ret[arg] = kwargs[arg]
return ret
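To make the `clip` check above concrete, here is a self-contained sketch showing how `_collect_args` filters kwargs against the optimizer signature. DummyOpt is a hypothetical stand-in, not a mindocr class.

import inspect

class DummyOpt:
    def __init__(self, params, lr=0.01, weight_decay=0.0):
        pass

def _collect_args(kwargs, optim_class):
    # same logic as above: keep only kwargs that match the optimizer's own parameters
    ret = {}
    valid_args = list(inspect.signature(optim_class.__init__).parameters.keys())[1:]  # remove self
    for arg in valid_args:
        assert arg != 'clip', 'Gradient clipping should not be set in `optimizer`. Please set it in `train`.'
        if arg in kwargs:
            ret[arg] = kwargs[arg]
    return ret

print(_collect_args({'lr': 1e-3, 'weight_decay': 1e-4, 'clip': True}, DummyOpt))
# -> {'lr': 0.001, 'weight_decay': 0.0001}; 'clip' is silently dropped here, and the assert
#    only fires if the optimizer class itself exposes a `clip` parameter.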
9 changes: 8 additions & 1 deletion mindocr/optim/param_grouping.py
@@ -65,10 +65,17 @@ def create_group_params(params, weight_decay=0, grouping_strategy=None, no_weigh
Return:
list[dict], grouped parameters
'''

# TODO: assert valid arg names
gp = grouping_strategy

print(f'INFO: param grouping strategy: {grouping_strategy}, no_weight_decay_params: {no_weight_decay_params}')
if gp is not None:
if weight_decay == 0:
print("WARNING: weight decay is 0 in param grouping.")
print("WARNING: weight decay is 0 in param grouping, which is meaningless. Please check config setting.")
if len(no_weight_decay_params) > 0:
print("WARNING: Both grouping_strategy and no_weight_decay_params are set, but grouping_strategy is of prior. no_weight_decay_params={no_weight_decay_params} will not make effect.")

if gp == 'svtr':
return grouping_svtr(params, weight_decay)
elif gp == 'filter_norm_and_bias':
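For readers unfamiliar with the grouping strategies referenced above, here is a rough sketch of what a `filter_norm_and_bias`-style strategy typically does. It is illustrative only, not the exact mindocr implementation; the suffix list and the use of `order_params` follow common MindSpore conventions.

def group_filter_norm_and_bias(params, weight_decay, no_decay_suffixes=('.gamma', '.beta', '.bias')):
    decay, no_decay = [], []
    for p in params:
        # norm and bias parameters are usually excluded from weight decay
        if len(p.shape) <= 1 or p.name.endswith(no_decay_suffixes):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
        {'order_params': list(params)},  # keep the original parameter order
    ]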
39 changes: 25 additions & 14 deletions mindocr/utils/loss_scaler.py
@@ -5,31 +5,42 @@ def get_loss_scales(cfg):
'''
Args:
cfg (dict): configure dict of loss scaler

Returns:
nn.Cell: scale_sens used to scale gradient
float: loss_scale used in optimizer (only used when loss scaler type is static and drop_overflow update is False)
nn.Cell: scale_sens used to scale gradient
float: loss_scale used in optimizer (only used when loss scaler type is static and drop_overflow update is False)
'''
# loss scale is 1.0 by default
loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=1.0)
optimizer_loss_scale = 1.0

if 'loss_scaler' in cfg:

# Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
# `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in `FixedLossScaleManager`
# But we never use FixedLossScaleManager, so optimizer_loss_scale is always 1.
optimizer_loss_scale = 1.0


if 'loss_scaler' in cfg:
assert 'loss_scale' in cfg.loss_scaler, 'Must specify the value for `loss_scale` in the config if `loss_scaler` is used.'
if cfg.loss_scaler.type == 'dynamic':
# TODO: scale_window can be related to num_batches, e.g., scale_window = num_batches * 2
loss_scale_manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scaler.get('loss_scale', 2**16),
scale_factor=cfg.loss_scaler.get('scale_factor', 2.0),
scale_window=cfg.loss_scaler.get('scale_window', 2000),
scale_factor = cfg.loss_scaler.get('scale_factor', 2.0)
scale_window = cfg.loss_scaler.get('scale_window', 2000)
# adjust by gradient_accumulation_steps so that loss scaling behaves as if the effective batch size were batch_size * gradient_accumulation_steps
grad_accu_steps = cfg.train.get('gradient_accumulation_steps', 1)
if grad_accu_steps > 1:
scale_factor = scale_factor ** (1/grad_accu_steps)
scale_window = scale_window * grad_accu_steps
print("INFO: gradient_accumulation_steps > 1, scale_factor and scale_window are adjusted accordingly for dynamic loss scaler")

loss_scale_manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scaler.get('loss_scale', 2**16),
scale_factor=scale_factor,
scale_window=scale_window,
)
elif cfg.loss_scaler.type == 'static':
loss_scale = cfg.loss_scaler.get('loss_scale', 1.0)
loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale)
# when using static loss scaler and drop_overflow_update is False, we should also set loss_scale for optimizer.
if not cfg.system.drop_overflow_update:
optimizer_loss_scale = loss_scale
else:
raise ValueError(f'Available loss scaler types are `static` and `dynamic`, but got {cfg.loss_scaler}')
return loss_scale_manager, optimizer_loss_scale

return loss_scale_manager, optimizer_loss_scale
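To make the adjustment above concrete with the values set in crnn_icdar15.yaml in this PR (scale_factor: 2.0, scale_window: 1000, gradient_accumulation_steps: 2), here is a small numeric check; the printed values are what the dynamic scaler would actually use.

import math

scale_factor, scale_window, grad_accu_steps = 2.0, 1000, 2
adj_factor = scale_factor ** (1 / grad_accu_steps)   # 2.0 ** 0.5 ~= 1.414
adj_window = scale_window * grad_accu_steps          # 1000 * 2 = 2000
# per the comment above, the gentler factor and the proportionally longer window are meant
# to keep the dynamic scaling schedule comparable to training with
# batch_size * gradient_accumulation_steps under the unadjusted settings
assert math.isclose(adj_factor ** grad_accu_steps, scale_factor)
print(adj_factor, adj_window)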

135 changes: 106 additions & 29 deletions mindocr/utils/train_step_wrapper.py
@@ -32,12 +32,11 @@ def tensor_grad_scale_row_tensor(scale, grad):
grad.dense_shape,
)


class TrainOneStepWrapper(nn.TrainOneStepWithLossScaleCell):
"""TrainStep with ema and clip grad.
Args:
drop_overflow_update: if True, network will not be updated when gradient is overflow.
scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called
scale_sense (Union[Tensor, Cell]): If this value is a Cell, it will be called
to update loss scale. If this value is a Tensor, the loss scale can be modified by `set_sense_scale`,
the shape should be :math:`()` or :math:`(1,)`.

@@ -51,25 +50,49 @@ def __init__(
ema=False,
ema_decay=0.9999,
updates=0,
clip_grad=False, #TODO: adamw/lion opt also support clip grad. merge?
clip_value=15.0,
drop_overflow_update=True,
clip_grad=False, #TODO: adamw/lion opt also support clip grad. merge?
clip_norm=1.0,
gradient_accumulation_steps=1,
verbose=False,
):
super().__init__(network, optimizer, scale_sense)
self.ema = ema
self.ema_decay = ema_decay
self.updates = Parameter(Tensor(updates, ms.float32))
self.clip_grad = clip_grad
self.clip_value = clip_value
self.updates = Parameter(Tensor(updates, ms.float32), requires_grad=False)
self.drop_overflow_update = drop_overflow_update

assert isinstance(clip_grad, bool), f'Invalid type of clip_grad, got {type(clip_grad)}'
assert clip_norm > 0. and isinstance(clip_norm, float), f'clip_norm must be a float > 0., but got {clip_norm}'
self.clip_grad = clip_grad
self.clip_norm = clip_norm

# Gradient accumulation
assert gradient_accumulation_steps >= 1 and isinstance(gradient_accumulation_steps, int), f'gradient_accumulation_steps must be int >= 1, but got {gradient_accumulation_steps}'
self.grad_accu_steps = gradient_accumulation_steps
if self.grad_accu_steps > 1:
# additionally cache buffers shaped like the network's trainable parameters; this incurs memory overhead.
# TODO: try to store it in CPU memory instead of GPU/NPU memory.
self.accumulated_grads = optimizer.parameters.clone(prefix='grad_accumulated_', init='zeros')
self.zeros = optimizer.parameters.clone(prefix='zeros_', init='zeros')
for p in self.accumulated_grads:
p.requires_grad = False
for z in self.zeros:
z.requires_grad = False
self.cur_accu_step = Parameter(Tensor(0, ms.int32), 'grad_accumulate_step_', requires_grad=False)
self.zero = Tensor(0, ms.int32)
else:
self.cur_accu_step = 0 # allows the model to be updated every step

self.verbose = verbose
if self.ema:
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init="same")

self.is_cpu_device = context.get_context("device_target") == 'CPU' # to support CPU run
print('\n====-> device: ', context.get_context("device_target") )
self.is_cpu_device = context.get_context("device_target") == 'CPU' # to support CPU in CI

self.map = ops.Map()
self.partial = ops.Partial()

def ema_update(self):
"""Update EMA parameters."""
@@ -81,39 +104,93 @@ def ema_update(self):
return self.updates

def construct(self, *inputs):
# compute loss
weights = self.weights
loss = self.network(*inputs)
loss = self.network(*inputs) # mini-batch loss
scaling_sens = self.scale_sense


# check loss overflow
if not self.is_cpu_device:
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
else:
status = None

scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
status = None

# up-scale loss with loss_scale value and gradient accumulation steps
# NOTE: we choose to take the mean over gradient accumulation steps here, for consistency with the gradient accumulation implementation in PyTorch.
scaled_loss = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) / F.cast(self.grad_accu_steps, F.dtype(loss))

# compute gradients
grads = self.grad(self.network, weights)(*inputs, scaled_loss)

# down-scale gradients with the loss_scale value only (as a result, it is the same as dividing the accumulated gradients by the accumulation steps)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
# apply grad reducer on grads

# gradient reduction on distributed GPUs/NPUs
grads = self.grad_reducer(grads)

# get the overflow buffer
# check gradient overflow
if not self.is_cpu_device:
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
else:
overflow = ms.Tensor(False)
cond = ms.Tensor(False)

if self.drop_overflow_update:
# if there is no overflow, do optimize
if not overflow :
loss = F.depend(loss, self.optimizer(grads))
if self.ema:
self.ema_update()
cond = ms.Tensor(False)

#print(0, grads[0][0][0])

# accumulate gradients and update model weights if no overflow or allow to update even when overflow
if (not self.drop_overflow_update) or (not overflow):
# gradient accumulation
if self.grad_accu_steps > 1:
success = F.depend(loss, self.map(self.partial(ops.assign_add), self.accumulated_grads, grads)) # self.accumulated_grads += grads
success = F.depend(success, ops.assign_add(self.cur_accu_step, Tensor(1, ms.int32))) # self.cur_accu_step += 1
accu_grads = self.accumulated_grads
else:
success = loss
accu_grads = grads

# optimize
# TODO: consider the last accumulation round, which is now skipped
if self.cur_accu_step % self.grad_accu_steps == 0:
#print(1, accu_grads[0][0][0])
# clip grad
if self.clip_grad:
clipped_grads = ops.clip_by_global_norm(accu_grads, self.clip_norm)
else:
clipped_grads = accu_grads
#print(2, clipped_grads[0][0][0])

# NOTE: no need to divide the accumulated grads by the accumulation steps since we've already divided the loss by the steps.
success = F.depend(success, self.optimizer(clipped_grads))

# EMA of model weights
#if self.ema:
# self.ema_update()

# clear grad accumulation states
if self.grad_accu_steps > 1:
success = F.depend(success, self.map(self.partial(ops.assign), self.accumulated_grads, self.zeros)) # self.accumulated_grads = 0
success = F.depend(success, ops.assign(self.cur_accu_step, self.zero)) # self.cur_accu_step = 0

else:
# still run the optimizer even on overflow
loss = F.depend(loss, self.optimizer(grads))
if self.ema:
self.ema_update()
print("WARNING: Gradient overflow! update skipped.")
pass

return loss, cond, scaling_sens


def _get_gradient_accumulation_fn(self):
# code adapted from mindyolo
hyper_map = ops.HyperMap()

def accu_fn(g1, g2):
g1 = g1 + g2
return g1

def gradient_accumulation_fn(accumulated_grads, grads):
success = hyper_map(accu_fn, accumulated_grads, grads)
return success

return gradient_accumulation_fn
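Since the NOTE in construct() above mentions matching PyTorch's gradient accumulation convention (divide the loss by the accumulation steps and step the optimizer every N micro-batches), here is a small self-contained PyTorch reference of that scheme for comparison. The model, data and hyper-parameters below are illustrative, not part of this PR.

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(6)]
accum_steps = 2  # plays the role of gradient_accumulation_steps

optimizer.zero_grad()
for step, (x, y) in enumerate(data, start=1):
    loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps  # mean over accumulation steps
    loss.backward()                                                 # grads accumulate in .grad buffers
    if step % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # counterpart of clip_norm
        optimizer.step()
        optimizer.zero_grad()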

