Delete gradient accumulation due to potential loss non-convergence in DBNet #192


Merged: 3 commits, Apr 17, 2023
5 changes: 2 additions & 3 deletions configs/rec/crnn/crnn_icdar15.yaml
@@ -68,15 +68,14 @@ loss_scaler:
  scale_window: 1000

train:
-  gradient_accumulation_steps: 2
+  #gradient_accumulation_steps: 2
  clip_grad: True
  clip_norm: 0.0001
  ckpt_save_dir: './tmp_rec'
  dataset_sink_mode: False
  dataset:
    type: RecDataset
-    #dataset_root: /data/ocr_datasets
-    dataset_root: /Users/Samit/Data/datasets
+    dataset_root: /data/ocr_datasets
    data_dir: ic15/rec/train/ch4_training_word_images_gt
    label_file: ic15/rec/train/rec_gt.txt
    sample_ratio: 1.0
3 changes: 2 additions & 1 deletion mindocr/utils/loss_scaler.py
@@ -26,12 +26,13 @@ def get_loss_scales(cfg):
    scale_factor=cfg.loss_scaler.get('scale_factor', 2.0)
    scale_window = cfg.loss_scaler.get('scale_window', 2000)
    # adjust by gradient_accumulation_steps so that the scaling process is the same as that of batch_size=batch_size*gradient_accumulation_steps
+    '''
    grad_accu_steps = cfg.train.get('gradient_accumulation_steps', 1)
    if grad_accu_steps > 1:
        scale_factor = scale_factor ** (1/grad_accu_steps)
        scale_window = scale_window * grad_accu_steps
        print("INFO: gradient_accumulation_steps > 1, scale_factor and scale_window are adjusted accordingly for dynamic loss scaler")
-
+    '''
    loss_scale_manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scaler.get('loss_scale', 2**16),
                                                       scale_factor=scale_factor,
                                                       scale_window=scale_window,
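For reference, a minimal sketch of the arithmetic that the block now wrapped in the triple-quoted string used to apply (the numbers are illustrative defaults, not taken from any particular config):

    # Adjustment disabled by this PR: with gradient accumulation enabled, the dynamic
    # loss scaler was updated once per micro-step, so scale_factor and scale_window
    # were rescaled with the intent of mimicking a run at batch_size * gradient_accumulation_steps.
    scale_factor, scale_window, grad_accu_steps = 2.0, 2000, 4   # illustrative values

    adjusted_factor = scale_factor ** (1 / grad_accu_steps)      # 2.0 ** 0.25 ≈ 1.189
    adjusted_window = scale_window * grad_accu_steps             # 8000 micro-steps

With the block disabled, get_loss_scales passes the configured scale_factor and scale_window to DynamicLossScaleUpdateCell unchanged.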
90 changes: 15 additions & 75 deletions mindocr/utils/train_step_wrapper.py
@@ -51,9 +51,9 @@ def __init__(
        ema_decay=0.9999,
        updates=0,
        drop_overflow_update=True,
-       clip_grad=False, #TODO: adamw/lion opt also support clip grad. merge?
-       clip_norm=1.0,
-       gradient_accumulation_steps=1,
+       clip_grad=False,
+       clip_norm=1.0,
        verbose=False,
    ):
        super().__init__(network, optimizer, scale_sense)
@@ -62,28 +62,11 @@ def __init__(
        self.updates = Parameter(Tensor(updates, ms.float32), requires_grad=False)
        self.drop_overflow_update = drop_overflow_update

-       assert isinstance(clip_grad, bool), f'Invalid type of clip_grad, got {type(clip_grad)}'
+       assert isinstance(clip_grad, bool), f'Invalid type of clip_grad, got {type(clip_grad)}, expected bool'
        assert clip_norm > 0. and isinstance(clip_norm, float), f'clip_norm must be float > 1.0, but got {clip_norm}'
        self.clip_grad = clip_grad
        self.clip_norm = clip_norm

-       # Gradient accumulation
-       assert gradient_accumulation_steps >= 1 and isinstance(gradient_accumulation_steps, int), f'gradient_accumulation_steps must be int >= 1, but got {gradient_accumulation_steps}'
-       self.grad_accu_steps = gradient_accumulation_steps
-       if self.grad_accu_steps > 1:
-           # additionally caches network trainable parameters. overhead caused.
-           # TODO: try to store it in CPU memory instead of GPU/NPU memory.
-           self.accumulated_grads = optimizer.parameters.clone(prefix='grad_accumulated_', init='zeros')
-           self.zeros = optimizer.parameters.clone(prefix='zeros_', init='zeros')
-           for p in self.accumulated_grads:
-               p.requires_grad = False
-           for z in self.zeros:
-               z.requires_grad = False
-           self.cur_accu_step = Parameter(Tensor(0, ms.int32), 'grad_accumulate_step_', requires_grad=False)
-           self.zero = Tensor(0, ms.int32)
-       else:
-           self.cur_accu_step = 0  # it will allow to update model every step
-
        self.verbose = verbose
        if self.ema:
            self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
@@ -115,14 +98,12 @@ def construct(self, *inputs):
        else:
            status = None

-       # up-scale loss with loss_scale value and gradient accumulation steps
-       # NOTE: we choose to take mean over gradient accumulation steps here for consistency with the gradient accumulation implementation in PyTorch.
-       scaled_loss = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) / F.cast(self.grad_accu_steps, F.dtype(loss))
+       scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))  # loss scale value

-       # compute gradients
-       grads = self.grad(self.network, weights)(*inputs, scaled_loss)
+       # compute gradients (of the up-scaled loss w.r.t. the model weights)
+       grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)

-       # down-scale gradients with loss_scale value only (as a result, it is the same as dividing accumulated gradients by the accumulation steps)
+       # down-scale gradients with loss_scale value. TODO: divide scaling_sens by accumulation steps for grad accumulate
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)

        # gradient reduction on distributed GPUs/NPUs
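For context, a small plain-Python sketch (not the repository's code) of the scale/unscale round trip in the hunk above: backpropagation is seeded with the loss-scale value instead of 1, so every gradient carries that factor, and the _grad_scale step divides it back out; loss_scale and true_grads below are illustrative values.

    # Multiplying and dividing by a power of two only shifts the exponent, so the
    # round trip is exact; its purpose is to keep small fp16 gradients from
    # underflowing to zero during backprop.
    loss_scale = 1024.0                              # illustrative loss-scale value (2**10)
    true_grads = [0.001, -0.0005]                    # pretend unscaled gradients
    scaled = [g * loss_scale for g in true_grads]    # what backprop of the scaled loss yields
    recovered = [g / loss_scale for g in scaled]     # _grad_scale-style down-scaling
    assert recovered == true_grads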
@@ -136,61 +117,20 @@ def construct(self, *inputs):
            overflow = ms.Tensor(False)
            cond = ms.Tensor(False)

-       #print(0, grads[0][0][0])
-
-       # accumulate gradients and update model weights if no overflow or allow to update even when overflow
        if (not self.drop_overflow_update) or (not overflow):
-           # gradient accumulation
-           if self.grad_accu_steps > 1:
-               success = F.depend(loss, self.map(self.partial(ops.assign_add), self.accumulated_grads, grads))  # self.accumulated_grads += grads
-               success = F.depend(success, ops.assign_add(self.cur_accu_step, Tensor(1, ms.int32)))  # self.cur_accu_step += 1
-               accu_grads = self.accumulated_grads
-           else:
-               success = loss
-               accu_grads = grads
+           # clip grad
+           if self.clip_grad:
+               grads = ops.clip_by_global_norm(grads, self.clip_norm)

-           # optimize
-           # TODO: consider the last accumulation round, which is now skipped
-           if self.cur_accu_step % self.grad_accu_steps == 0:
-               #print(1, accu_grads[0][0][0])
-               # clip grad
-               if self.clip_grad:
-                   clipped_grads = ops.clip_by_global_norm(accu_grads, self.clip_norm)
-               else:
-                   clipped_grads = accu_grads
-               #print(2, clipped_grads[0][0][0])
-
-               # NOTE: no need to divide accumulated grads by the accumulation steps since we've already divided the loss by the steps.
-               success = F.depend(success, self.optimizer(clipped_grads))
-
-               # EMA of model weights
-               #if self.ema:
-               #    self.ema_update()
-
-               # clear grad accumulation states
-               if self.grad_accu_steps > 1:
-                   success = F.depend(success, self.map(self.partial(ops.assign), self.accumulated_grads, self.zeros))  # self.accumulated_grads = 0
-                   success = F.depend(success, ops.assign(self.cur_accu_step, self.zero))  # self.cur_accu_step = 0
+           loss = F.depend(loss, self.optimizer(grads))

            # EMA of model weights
            if self.ema:
                self.ema_update()
        else:
-           print("WARNING: Gradient overflow! update skipped.")
+           #print("WARNING: Gradient overflow! update skipped.")
+           pass

        return loss, cond, scaling_sens


-   def _get_gradient_accumulation_fn(self):
-       # code adopted from mindyolo
-       hyper_map = ops.HyperMap()
-
-       def accu_fn(g1, g2):
-           g1 = g1 + g2
-           return g1
-
-       def gradient_accumulation_fn(accumulated_grads, grads):
-           success = hyper_map(accu_fn, accumulated_grads, grads)
-           return success
-
-       return gradient_accumulation_fn
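For context, a minimal PyTorch-style sketch of the accumulate-then-step pattern that the deleted NOTE refers to; the toy model, optimizer, batch sizes, and step count are assumptions for illustration, not the repository's code. Dividing each micro-batch loss by the number of accumulation steps turns the summed gradients into a mean, which is why the deleted code did not divide the accumulated gradients again before the optimizer update.

    import torch

    accumulation_steps = 2                                   # illustrative value
    model = torch.nn.Linear(4, 1)                            # toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for step in range(8):
        x = torch.randn(16, 4)                               # illustrative micro-batch
        loss = model(x).pow(2).mean() / accumulation_steps   # mean over accumulation steps
        loss.backward()                                      # gradients add up across backward calls
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()                                 # one update per accumulated "big batch"
            optimizer.zero_grad()                            # clear accumulated gradients

With accumulation removed, the wrapper above instead applies clipping and the optimizer on every step.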


6 changes: 3 additions & 3 deletions tests/st/test_train_eval_dummy.py
@@ -28,14 +28,14 @@ def _create_combs():

    val_while_train = False
    for task in ["det", "rec"]:
-        for gradient_accumulation_steps in [1, 2]:
+        for gradient_accumulation_steps in [1]: #[1, 2]:
            for clip_grad in [False, True]:
                combs.add((task, val_while_train, gradient_accumulation_steps, clip_grad, grouping_strategy))

    task, val_while_train, gradient_accumulation_steps, clip_grad, grouping_strategy = 'rec', False, 1, False, None
    for grouping_strategy in [None, 'filter_norm_and_bias']:
-        for gradient_accumulation_steps in [1, 2]:
-            combs.add((task, val_while_train, gradient_accumulation_steps, clip_grad, grouping_strategy))
+        for gradient_accumulation_steps in [1]: #[1, 2]:
+            combs.add((task, val_while_train, gradient_accumulation_steps, clip_grad, grouping_strategy))
    print(combs)
    return list(combs)