Commit

fix ema and edgenext yaml bugs
songyuanwei committed Jul 10, 2023
1 parent 2f24f90 commit dd65c83
Showing 5 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions configs/edgenext/edgenext_base_ascend.yaml
@@ -60,4 +60,5 @@ filter_bias_and_bn: True
 momentum: 0.9
 weight_decay: 0.05
 loss_scale: 1024
+drop_overflow_update: True
 use_nesterov: False
1 change: 1 addition & 0 deletions configs/edgenext/edgenext_small_ascend.yaml
@@ -59,4 +59,5 @@ filter_bias_and_bn: True
 momentum: 0.9
 weight_decay: 0.05
 loss_scale: 1024
+drop_overflow_update: True
 use_nesterov: False
1 change: 1 addition & 0 deletions configs/edgenext/edgenext_x_small_ascend.yaml
@@ -59,4 +59,5 @@ filter_bias_and_bn: True
 momentum: 0.9
 weight_decay: 0.05
 loss_scale: 1024
+drop_overflow_update: True
 use_nesterov: False
1 change: 1 addition & 0 deletions configs/edgenext/edgenext_xx_small_ascend.yaml
@@ -58,4 +58,5 @@ filter_bias_and_bn: True
 momentum: 0.9
 weight_decay: 0.05
 loss_scale: 1024
+drop_overflow_update: True
 use_nesterov: False
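
For context, drop_overflow_update: True tells the loss-scale handling to skip the optimizer update on any step where a gradient overflow is detected, instead of applying gradients that overflowed under scaling. Below is a rough sketch of how such keys are typically consumed in MindSpore; the exact wiring inside mindcv may differ, so treat it as an illustration rather than this repository's code.

# Illustrative sketch only (not mindcv code): a fixed loss-scale manager that
# skips the parameter update on overflow, mirroring the YAML keys above.
import mindspore as ms

loss_scale_manager = ms.FixedLossScaleManager(
    loss_scale=1024,            # corresponds to "loss_scale: 1024"
    drop_overflow_update=True,  # corresponds to the newly added key
)
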
8 changes: 6 additions & 2 deletions mindcv/utils/train_step.py
@@ -152,6 +152,8 @@ def construct(self, *inputs):
                 # if there is no overflow, do optimize
                 if not overflow:
                     loss = self.gradient_accumulation(loss, grads)
+                    if self.ema:
+                        loss = F.depend(loss, self.ema_update())
             else:
                 # apply grad reducer on grads
                 grads = self.grad_reducer(grads)
@@ -161,14 +163,16 @@ def construct(self, *inputs):
                 # if there is no overflow, do optimize
                 if not overflow:
                     loss = F.depend(loss, self.optimizer(grads))
+                    if self.ema:
+                        loss = F.depend(loss, self.ema_update())
         else:  # scale_sense = loss_scale: Tensor --> TrainOneStepCell.construct
             if self.accumulate_grad:
                 loss = self.gradient_accumulation(loss, grads)
             else:
                 grads = self.grad_reducer(grads)
                 loss = F.depend(loss, self.optimizer(grads))

-        if self.ema:
-            loss = F.depend(loss, self.ema_update())
+            if self.ema:
+                loss = F.depend(loss, self.ema_update())

         return loss
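
In effect, the change moves the EMA update into the branches where the optimizer actually applies the gradients, so the EMA weights are no longer refreshed on steps that were skipped because of a gradient overflow. A simplified, hypothetical sketch of that control flow (illustrative names, not the TrainStep implementation itself):

# Hypothetical sketch of the intended behavior: only refresh the EMA copy of
# the weights after the optimizer step has actually been applied.
def apply_update(optimizer, ema_update, grads, overflow, use_ema=True):
    if overflow:
        # Overflow under loss scaling: skip both the parameter update and
        # the EMA update for this step.
        return False
    optimizer(grads)   # apply gradients to the model weights
    if use_ema:
        ema_update()   # EMA now only averages in weights that were really updated
    return True
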
