Merge pull request #483 from Songyuanwei/branch_2
modify ema name and clip_grad
geniuspatrick authored Mar 6, 2023
2 parents 8d21f99 + 01f0d14 commit bc51cb3
Showing 12 changed files with 38 additions and 41 deletions.
4 changes: 2 additions & 2 deletions config.py
@@ -153,10 +153,10 @@ def create_parser():
                        help='Input channels (default=3)')
     group.add_argument('--ckpt_save_policy', type=str, default='latest_k',
                        help='Checkpoint saving strategy. The optional values is None, "top_k" or "latest_k".')
-    group.add_argument('--use_ema', type=str2bool, nargs='?', const=True, default=False,
+    group.add_argument('--ema', type=str2bool, nargs='?', const=True, default=False,
                        help='training with ema (default=False)')
     group.add_argument('--ema_decay', type=float, default=0.9999, help='ema decay')
-    group.add_argument('--use_clip_grad', type=str2bool, nargs='?', const=True, default=False,
+    group.add_argument('--clip_grad', type=str2bool, nargs='?', const=True, default=False,
                        help='Whether use clip grad (default=False)')
     group.add_argument('--clip_value', type=float, default=15.0, help='clip value')

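For context, after this rename training is enabled with `--ema` / `--clip_grad` instead of `--use_ema` / `--use_clip_grad`. A minimal, self-contained sketch of how the renamed options parse (reconstructed for illustration; the real parser lives in MindCV's config.py and its `str2bool` may differ):

```python
import argparse

def str2bool(v):
    # argparse has no native bool type; MindCV routes booleans through
    # a helper like this (illustrative reconstruction).
    return str(v).lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument('--ema', type=str2bool, nargs='?', const=True, default=False,
                    help='training with ema (default=False)')
parser.add_argument('--clip_grad', type=str2bool, nargs='?', const=True, default=False,
                    help='Whether use clip grad (default=False)')
parser.add_argument('--clip_value', type=float, default=15.0, help='clip value')

# Old: python train.py --use_ema --use_clip_grad
# New: python train.py --ema --clip_grad
args = parser.parse_args(['--ema', '--clip_grad'])
print(args.ema, args.clip_grad, args.clip_value)  # True True 15.0
```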
2 changes: 1 addition & 1 deletion configs/edgenext/edgenext_small_ascend.yaml
@@ -25,7 +25,7 @@ re_prob: 0.0
 cutmix: 1.0
 cutmix_prob: 0.0
 auto_augment: 'randaug-m9-mstd0.5-inc1'
-use_ema: True
+ema: True
 ema_decay: 0.9995

 # model
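The YAML keys map one-to-one onto the CLI flags, which is why the config key is renamed together with the parser option. A minimal sketch of that usual pattern (PyYAML-based; `apply_yaml_config` is an illustrative helper, not MindCV's actual config loader):

```python
import argparse

import yaml  # PyYAML, assumed available

def apply_yaml_config(parser: argparse.ArgumentParser, path: str) -> argparse.Namespace:
    # Use the YAML values as new parser defaults, so flags given
    # explicitly on the command line still override the file.
    with open(path) as f:
        cfg = yaml.safe_load(f)
    parser.set_defaults(**cfg)
    return parser.parse_args([])

# After this commit, a config line `ema: True` lands in `args.ema`.
```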
8 changes: 4 additions & 4 deletions mindcv/engine/callbacks.py
@@ -39,7 +39,7 @@ def __init__(
         last_epoch=0,
         keep_checkpoint_max=10,
         ckpt_save_policy=None,
-        use_ema=False,
+        ema=False,
         dataset_sink_mode=True,
     ):
         super().__init__()
@@ -89,8 +89,8 @@ def __init__(
         self.start = time()
         self.epoch_start = time()
         self.map = ops.HyperMap()
-        self.use_ema = use_ema
-        if self.use_ema:
+        self.ema = ema
+        if self.ema:
             self.online_params = ParameterTuple(self.model.train_network.get_parameters())
             self.swap_params = self.online_params.clone("swap", "zeros")

@@ -103,7 +103,7 @@ def __exit__(self, *exc_args):

     def apply_eval(self, run_context):
         """Model evaluation, return validation accuracy."""
-        if self.use_ema:
+        if self.ema:
             cb_params = run_context.original_args()
             self.map(ops.assign, self.swap_params, self.online_params)
             ema_dict = dict()
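Why the callback keeps `swap_params`: when evaluating with EMA enabled, it stashes the online weights, loads the EMA shadow weights into the network, runs validation, then restores the online weights so training continues unaffected. A framework-neutral sketch of that swap-evaluate-restore pattern (plain dicts stand in for MindSpore `ParameterTuple`s; names are illustrative):

```python
def eval_with_ema(weights, ema_weights, evaluate):
    # 1) stash the online weights, 2) evaluate with the EMA shadow copy,
    # 3) restore the online weights before training resumes.
    swap = dict(weights)
    weights.update(ema_weights)
    try:
        acc = evaluate()
    finally:
        weights.update(swap)
    return acc
```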
14 changes: 7 additions & 7 deletions mindcv/engine/train_step.py
@@ -39,19 +39,19 @@ def __init__(
         network,
         optimizer,
         scale_sense=1.0,
-        use_ema=False,
+        ema=False,
         ema_decay=0.9999,
         updates=0,
-        use_clip_grad=False,
+        clip_grad=False,
         clip_value=15.0,
     ):
         super(TrainStep, self).__init__(network, optimizer, scale_sense)
-        self.use_ema = use_ema
+        self.ema = ema
         self.ema_decay = ema_decay
         self.updates = Parameter(Tensor(updates, ms.float32))
-        self.use_clip_grad = use_clip_grad
+        self.clip_grad = clip_grad
         self.clip_value = clip_value
-        if self.use_ema:
+        if self.ema:
             self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
             self.ema_weight = self.weights_all.clone("ema", init="same")

@@ -73,11 +73,11 @@ def construct(self, *inputs):
         scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
         grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
         grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
-        if self.use_clip_grad:
+        if self.clip_grad:
             grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         loss = F.depend(loss, self.optimizer(grads))
-        if self.use_ema:
+        if self.ema:
             self.ema_update()
         return loss
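The two renamed switches gate independent steps inside `construct`: `clip_grad` clips gradients by global norm before the optimizer update, and `ema` updates the shadow weights after it. A framework-neutral sketch of both operations (NumPy stands in for MindSpore ops; this illustrates the math, not MindCV's implementation):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm=15.0):
    # Scale every gradient by min(1, clip_norm / ||g||_2) so the joint
    # L2 norm across all gradients never exceeds clip_norm.
    global_norm = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    scale = min(1.0, clip_norm / (global_norm + 1e-12))
    return [g * scale for g in grads]

def ema_update(ema_weights, weights, ema_decay=0.9999):
    # Shadow update per parameter: ema_w <- decay * ema_w + (1 - decay) * w.
    return [ema_decay * e + (1.0 - ema_decay) * w
            for e, w in zip(ema_weights, weights)]
```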
10 changes: 5 additions & 5 deletions mindcv/models/model_factory.py
@@ -13,7 +13,7 @@ def create_model(
     pretrained=False,
     in_channels: int = 3,
     checkpoint_path: str = "",
-    use_ema=False,
+    ema=False,
     **kwargs,
 ):
     r"""Creates model by name.
@@ -24,7 +24,7 @@ def create_model(
         pretrained (bool): Whether to load the pretrained model. Default: False.
         in_channels (int): The input channels. Default: 3.
         checkpoint_path (str): The path of checkpoint files. Default: "".
-        use_ema (bool): Whether use ema method. Default: False.
+        ema (bool): Whether use ema method. Default: False.
     """

     if checkpoint_path != "" and pretrained:
@@ -49,10 +49,10 @@ def create_model(
             ema_data.name = new_name
             ema_param_dict[new_name] = ema_data

-    if ema_param_dict and use_ema:
+    if ema_param_dict and ema:
         load_param_into_net(model, ema_param_dict)
-    elif bool(ema_param_dict) is False and use_ema:
-        raise ValueError("chekpoint_param does not contain ema_parameter, please set use_ema is False.")
+    elif bool(ema_param_dict) is False and ema:
+        raise ValueError("chekpoint_param does not contain ema_parameter, please set ema is False.")
     else:
         load_param_into_net(model, checkpoint_param)

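For orientation, the surrounding (unchanged) code in `create_model` builds `ema_param_dict` by pulling the EMA shadow copies out of the loaded checkpoint and renaming them to match the online parameters; this diff only renames the flag that selects them. A standalone sketch of that selection step (the `"ema."` prefix and dict-based checkpoint are assumptions for illustration):

```python
def collect_ema_params(checkpoint_param, prefix="ema."):
    # Keep only the EMA shadow weights, stored under their online
    # parameter names, so they can replace the regular parameters.
    ema_param_dict = {}
    for name, data in checkpoint_param.items():
        if name.startswith(prefix):
            ema_param_dict[name[len(prefix):]] = data
    return ema_param_dict

ckpt = {"ema.conv.weight": "shadow", "conv.weight": "online"}
print(collect_ema_params(ckpt))  # {'conv.weight': 'shadow'}
```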
2 changes: 1 addition & 1 deletion quick_start_CN.md
@@ -181,7 +181,7 @@ network = create_model(model_name='densenet121', num_classes=num_classes, pretra

 - checkpoint_path: the path of the checkpoint. Default: "".

-- use_ema: whether to use the ema method. Default: False.
+- ema: whether to use the ema method. Default: False.

 Use the [mindcv.loss.create_loss](https://mindcv.readthedocs.io/en/latest/api/mindcv.loss.html#mindcv.loss.create_loss) interface to create the loss function (cross_entropy loss).

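Accordingly, callers of `create_model` now pass `ema=True` where they previously passed `use_ema=True`. A hypothetical call after this commit (import path and checkpoint path illustrative):

```python
from mindcv.models import create_model

# Load the EMA shadow weights from the checkpoint instead of the online ones.
network = create_model(model_name='densenet121',
                       num_classes=1000,
                       checkpoint_path='./ckpt/densenet121_ema.ckpt',  # illustrative
                       ema=True)
```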
11 changes: 4 additions & 7 deletions tests/modules/non_cpu/test_utils.py
@@ -11,8 +11,8 @@
 from mindspore.common.initializer import Normal
 from mindspore.nn import WithLossCell

+from mindcv.engine import TrainStep
 from mindcv.optim import create_optimizer
-from mindcv.utils import TrainOneStepWithEMA

 ms.set_seed(1)
 np.random.seed(1)
@@ -46,9 +46,9 @@ def construct(self, x):
         return ret


-@pytest.mark.parametrize("use_ema", [True, False])
+@pytest.mark.parametrize("ema", [True, False])
 @pytest.mark.parametrize("ema_decay", [0.9997, 0.5])
-def test_ema(use_ema, ema_decay):
+def test_ema(ema, ema_decay):
     network = SimpleCNN(in_channels=1, num_classes=10)
     net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
@@ -60,10 +60,7 @@ def test_ema(use_ema, ema_decay):

     net_with_loss = WithLossCell(network, net_loss)
     loss_scale_manager = Tensor(1, ms.float32)
-    train_network = TrainOneStepWithEMA(
-        net_with_loss, net_opt, scale_sense=loss_scale_manager, use_ema=use_ema, ema_decay=ema_decay
-    )
-
+    train_network = TrainStep(net_with_loss, net_opt, scale_sense=loss_scale_manager, ema=ema, ema_decay=ema_decay)
     train_network.set_train()

     begin_loss = train_network(input_data, label)
6 changes: 3 additions & 3 deletions tests/tasks/non_cpu/test_train_val_imagenet_subset.py
@@ -15,9 +15,9 @@
 check_acc = True


-@pytest.mark.parametrize("use_ema", [True, False])
+@pytest.mark.parametrize("ema", [True, False])
 @pytest.mark.parametrize("val_while_train", [True, False])
-def test_train_ema(use_ema, val_while_train, model="resnet18"):
+def test_train_ema(ema, val_while_train, model="resnet18"):
     """train on a imagenet subset dataset"""
     # prepare data
     data_dir = "data/Canidae"
@@ -47,7 +47,7 @@ def test_train_ema(use_ema, val_while_train, model="resnet18"):
         f"--epoch_size={num_epochs} --ckpt_save_interval=2 --lr=0.0001 --num_samples={num_samples} "
         f"--loss=CE --weight_decay=1e-6 --ckpt_save_dir={ckpt_dir} {download_str} --train_split=train "
         f"--batch_size={batch_size} --pretrained --num_parallel_workers=2 --val_while_train={val_while_train} "
-        f"--val_split=val --val_interval=1 --use_ema"
+        f"--val_split=val --val_interval=1 --ema"
     )

     print(f"Running command: \n{cmd}")
12 changes: 6 additions & 6 deletions train.py
@@ -151,7 +151,7 @@ def train(args):
         drop_path_rate=args.drop_path_rate,
         pretrained=args.pretrained,
         checkpoint_path=args.ckpt_path,
-        use_ema=args.use_ema,
+        ema=args.ema,
     )

     num_params = sum([param.size for param in network.get_parameters()])
@@ -189,7 +189,7 @@ def train(args):

     # create optimizer
     # TODO: consistent naming opt, name, dataset_name
-    if args.use_ema or args.dynamic_loss_scale:
+    if args.ema or args.dynamic_loss_scale:
         optimizer = create_optimizer(
             network.trainable_params(),
             opt=args.opt,
@@ -223,17 +223,17 @@

     # init model
     # TODO: add dynamic_loss_scale for ema and clip_grad
-    if args.use_ema or args.use_clip_grad:
+    if args.ema or args.clip_grad:
         net_with_loss = nn.WithLossCell(network, loss)
         loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
         ms.amp.auto_mixed_precision(net_with_loss, amp_level=args.amp_level)
         net_with_loss = TrainStep(
             net_with_loss,
             optimizer,
             scale_sense=loss_scale_manager,
-            use_ema=args.use_ema,
+            ema=args.ema,
             ema_decay=args.ema_decay,
-            use_clip_grad=args.use_clip_grad,
+            clip_grad=args.clip_grad,
             clip_value=args.clip_value,
         )

@@ -287,7 +287,7 @@ def train(args):
             model_name=args.model,
             last_epoch=begin_epoch,
             ckpt_save_policy=args.ckpt_save_policy,
-            use_ema=args.use_ema,
+            ema=args.ema,
             dataset_sink_mode=args.dataset_sink_mode,
         )

4 changes: 2 additions & 2 deletions tutorials/learn_about_config.md
@@ -255,7 +255,7 @@ def train(args):
                            drop_path_rate=args.drop_path_rate,
                            pretrained=args.pretrained,
                            checkpoint_path=args.ckpt_path,
-                           use_ema=args.use_ema)
+                           ema=args.ema)
     ...
 ```

@@ -385,7 +385,7 @@ python train.py ... --opt momentum --filter_bias_and_bn True --weight_decay 0.00
 ```python
 def train(args):
     ...
-    if args.use_ema:
+    if args.ema:
         optimizer = create_optimizer(network.trainable_params(),
                                      opt=args.opt,
                                      lr=lr_scheduler,
4 changes: 2 additions & 2 deletions tutorials/learn_about_config_CN.md
@@ -255,7 +255,7 @@ def train(args):
                            drop_path_rate=args.drop_path_rate,
                            pretrained=args.pretrained,
                            checkpoint_path=args.ckpt_path,
-                           use_ema=args.use_ema)
+                           ema=args.ema)
     ...
 ```

@@ -385,7 +385,7 @@ python train.py ... --opt momentum --filter_bias_and_bn True --weight_decay 0.00
 ```python
 def train(args):
     ...
-    if args.use_ema:
+    if args.ema:
         optimizer = create_optimizer(network.trainable_params(),
                                      opt=args.opt,
                                      lr=lr_scheduler,
2 changes: 1 addition & 1 deletion validate.py
@@ -58,7 +58,7 @@ def validate(args):
         drop_path_rate=args.drop_path_rate,
         pretrained=args.pretrained,
         checkpoint_path=args.ckpt_path,
-        use_ema=args.use_ema,
+        ema=args.ema,
     )
     network.set_train(False)

