From dd65c8328ed222dd06b845dbd513748c99fdcd76 Mon Sep 17 00:00:00 2001 From: songyuanwei Date: Mon, 10 Jul 2023 17:04:46 +0800 Subject: [PATCH] fix ema and edgenext yaml bugs --- configs/edgenext/edgenext_base_ascend.yaml | 1 + configs/edgenext/edgenext_small_ascend.yaml | 1 + configs/edgenext/edgenext_x_small_ascend.yaml | 1 + configs/edgenext/edgenext_xx_small_ascend.yaml | 1 + mindcv/utils/train_step.py | 8 ++++++-- 5 files changed, 10 insertions(+), 2 deletions(-) diff --git a/configs/edgenext/edgenext_base_ascend.yaml b/configs/edgenext/edgenext_base_ascend.yaml index f08740c60..85293d0a7 100644 --- a/configs/edgenext/edgenext_base_ascend.yaml +++ b/configs/edgenext/edgenext_base_ascend.yaml @@ -60,4 +60,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/configs/edgenext/edgenext_small_ascend.yaml b/configs/edgenext/edgenext_small_ascend.yaml index 9962664fc..4c1ac8e01 100644 --- a/configs/edgenext/edgenext_small_ascend.yaml +++ b/configs/edgenext/edgenext_small_ascend.yaml @@ -59,4 +59,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/configs/edgenext/edgenext_x_small_ascend.yaml b/configs/edgenext/edgenext_x_small_ascend.yaml index 12fd8bd57..954e21fef 100644 --- a/configs/edgenext/edgenext_x_small_ascend.yaml +++ b/configs/edgenext/edgenext_x_small_ascend.yaml @@ -59,4 +59,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/configs/edgenext/edgenext_xx_small_ascend.yaml b/configs/edgenext/edgenext_xx_small_ascend.yaml index 47bea3ac6..4875dbf15 100644 --- a/configs/edgenext/edgenext_xx_small_ascend.yaml +++ b/configs/edgenext/edgenext_xx_small_ascend.yaml @@ -58,4 +58,5 @@ filter_bias_and_bn: True momentum: 0.9 weight_decay: 0.05 loss_scale: 1024 +drop_overflow_update: True use_nesterov: False diff --git a/mindcv/utils/train_step.py b/mindcv/utils/train_step.py index a23e427ef..e091e8c63 100644 --- a/mindcv/utils/train_step.py +++ b/mindcv/utils/train_step.py @@ -152,6 +152,8 @@ def construct(self, *inputs): # if there is no overflow, do optimize if not overflow: loss = self.gradient_accumulation(loss, grads) + if self.ema: + loss = F.depend(loss, self.ema_update()) else: # apply grad reducer on grads grads = self.grad_reducer(grads) @@ -161,6 +163,8 @@ def construct(self, *inputs): # if there is no overflow, do optimize if not overflow: loss = F.depend(loss, self.optimizer(grads)) + if self.ema: + loss = F.depend(loss, self.ema_update()) else: # scale_sense = loss_scale: Tensor --> TrainOneStepCell.construct if self.accumulate_grad: loss = self.gradient_accumulation(loss, grads) @@ -168,7 +172,7 @@ def construct(self, *inputs): grads = self.grad_reducer(grads) loss = F.depend(loss, self.optimizer(grads)) - if self.ema: - loss = F.depend(loss, self.ema_update()) + if self.ema: + loss = F.depend(loss, self.ema_update()) return loss