diff --git a/config.py b/config.py
index 63d39a9c..b4911a7b 100644
--- a/config.py
+++ b/config.py
@@ -110,6 +110,9 @@ def create_parser():
                        help='probability of applying cutmix and/or mixup (default=0.)')
     group.add_argument('--mixup', type=float, default=0.,
                        help='Hyperparameter of beta distribution of mixup. Recommended value is 0.2 for ImageNet. (default=0.)')
+    group.add_argument('--use_ema', type=str2bool, nargs='?', const=True, default=False,
+                       help='Train with an exponential moving average (EMA) of the weights (default=False)')
+    group.add_argument('--ema_decay', type=float, default=0.9999, help='EMA decay factor (default=0.9999)')

     # Model parameters
     group = parser.add_argument_group('Model parameters')
@@ -150,6 +153,8 @@ def create_parser():
     #group.add_argument('--loss_scaler', type=str, default='static', help='Loss scaler, static or dynamic (default=static)')
     group.add_argument('--loss_scale', type=float, default=1.0,
                        help='Loss scale (default=1.0)')
+    group.add_argument('--dynamic_loss_scale', type=str2bool, nargs='?', const=True, default=False,
+                       help='Whether to use dynamic loss scaling (default=False)')
     group.add_argument('--use_nesterov', type=str2bool, nargs='?', const=True, default=False,
                        help='Enables the Nesterov momentum (default=False)')
     group.add_argument('--filter_bias_and_bn', type=str2bool, nargs='?', const=True, default=True,
diff --git a/mindcv/models/model_factory.py b/mindcv/models/model_factory.py
index 30ab9568..71698ce7 100644
--- a/mindcv/models/model_factory.py
+++ b/mindcv/models/model_factory.py
@@ -11,6 +11,7 @@ def create_model(
         pretrained=False,
         in_channels: int = 3,
         checkpoint_path: str = '',
+        use_ema=False,
         **kwargs):
     r"""Creates model by name.

@@ -35,7 +36,21 @@ def create_model(
     model = create_fn(**model_args, **kwargs)

     if os.path.exists(checkpoint_path):
-        param_dict = load_checkpoint(checkpoint_path)
-        load_param_into_net(model, param_dict)
+
+        checkpoint_param = load_checkpoint(checkpoint_path)
+        ema_param_dict = dict()
+        for param in checkpoint_param:
+            if param.startswith("ema"):
+                new_name = param.split("ema.")[1]
+                ema_data = checkpoint_param[param]
+                ema_data.name = new_name
+                ema_param_dict[new_name] = ema_data
+
+        if ema_param_dict and use_ema:
+            load_param_into_net(model, ema_param_dict)
+        elif not ema_param_dict and use_ema:
+            raise ValueError('The checkpoint does not contain EMA parameters, please set use_ema to False.')
+        else:
+            load_param_into_net(model, checkpoint_param)

     return model
diff --git a/mindcv/utils/__init__.py b/mindcv/utils/__init__.py
index c486f437..1e793752 100644
--- a/mindcv/utils/__init__.py
+++ b/mindcv/utils/__init__.py
@@ -3,4 +3,5 @@
 from .download import *
 from .callbacks import *
 from .checkpoint_manager import *
-from .reduce_manager import *
+from .reduce_manager import *
+from .ema import *
diff --git a/mindcv/utils/callbacks.py b/mindcv/utils/callbacks.py
index 1ae03c96..6f08d515 100644
--- a/mindcv/utils/callbacks.py
+++ b/mindcv/utils/callbacks.py
@@ -97,19 +97,21 @@ def apply_eval(self):
     def on_train_step_end(self, run_context):
         cb_params = run_context.original_args()
         num_batches = cb_params.batch_num
-        #global_step = cb_params.optimizer.global_step.asnumpy()[0]
         cur_epoch = cb_params.cur_epoch_num + self.last_epoch -1 #(global_step-1) // num_batches
-        #cur_step_in_epoch = (global_step- 1) % cb_params.batch_num
         cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num)
+        if cb_params.optimizer is not None:
+            optimizer = cb_params.optimizer
+        else:
+            optimizer = cb_params.train_network.network.optimizer
         if (cur_step_in_epoch + 1) % self.log_interval == 0 or \
                (cur_step_in_epoch + 1) >= num_batches or cur_step_in_epoch == 0:
-            step = cb_params.optimizer.global_step
-            if cb_params.optimizer.dynamic_lr:
-                cur_lr = cb_params.optimizer.learning_rate(step-1)[0].asnumpy()
+            step = optimizer.global_step
+            if optimizer.dynamic_lr:
+                cur_lr = optimizer.learning_rate(step-1)[0].asnumpy()
             else:
-                cur_lr = cb_params.optimizer.learning_rate.asnumpy()
+                cur_lr = optimizer.learning_rate.asnumpy()
             loss = self._get_loss(cb_params)

             print(f"Epoch: {cur_epoch+1}, "
@@ -123,8 +125,13 @@ def on_train_epoch_end(self, run_context):
         save the best ckpt file with highest validation accuracy.
         """
         cb_params = run_context.original_args()
+        if cb_params.optimizer is not None:
+            optimizer = cb_params.optimizer
+        else:
+            optimizer = cb_params.train_network.network.optimizer
+
         # the global step may larger than batch_size * epoch due to graph mode async
-        global_step = cb_params.optimizer.global_step.asnumpy()[0]
+        global_step = optimizer.global_step.asnumpy()[0]
         cur_epoch = cb_params.cur_epoch_num + self.last_epoch
         cur_step_in_epoch = cb_params.batch_num #(global_step - 1) % cb_params.batch_num
@@ -170,7 +177,7 @@ def on_train_epoch_end(self, run_context):

             # save optim for resume
             optim_save_path = os.path.join(self.ckpt_dir, f'optim_{self.model_name}.ckpt')
-            ms.save_checkpoint(cb_params.optimizer, optim_save_path, async_save=True)
+            ms.save_checkpoint(optimizer, optim_save_path, async_save=True)

             cur_ckpoint_file = self.model_name + "-" + str(cur_epoch) + "_" \
                                + str(cur_step_in_epoch) + ".ckpt"
diff --git a/mindcv/utils/ema.py b/mindcv/utils/ema.py
new file mode 100644
index 00000000..c240cff9
--- /dev/null
+++ b/mindcv/utils/ema.py
@@ -0,0 +1,76 @@
+"""EMA (exponential moving average) definition."""
+
+import mindspore as ms
+from mindspore import nn, Tensor, Parameter, ParameterTuple
+from mindspore.common import RowTensor
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+_ema_op = C.MultitypeFuncGraph("grad_ema_op")
+_grad_scale = C.MultitypeFuncGraph("grad_scale")
+reciprocal = P.Reciprocal()
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+
+
+@_ema_op.register("Tensor", "Tensor", "Tensor")
+def _ema_weights(factor, ema_weight, weight):
+    return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))
+
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale(scale, grad):
+    return grad * F.cast(reciprocal(scale), F.dtype(grad))
+
+
+@_grad_scale.register("Tensor", "RowTensor")
+def tensor_grad_scale_row_tensor(scale, grad):
+    return RowTensor(grad.indices,
+                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
+                     grad.dense_shape)
+
+
+class TrainOneStepWithEMA(nn.TrainOneStepWithLossScaleCell):
+    """Train-one-step cell with loss scaling and an optional EMA of the model weights."""
+
+    def __init__(self, network, optimizer, scale_sense=1.0, use_ema=False, ema_decay=0.9999, updates=0):
+        super(TrainOneStepWithEMA, self).__init__(network, optimizer, scale_sense)
+        self.use_ema = use_ema
+        self.ema_decay = ema_decay
+        self.updates = Parameter(Tensor(updates, ms.float32))
+        if self.use_ema:
+            self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
+            self.ema_weight = self.weights_all.clone("ema", init='same')
+
+    def ema_update(self):
+        """Update EMA parameters."""
+        self.updates += 1
+        # decay ramps up from 0 towards ema_decay over roughly the first few thousand updates
+        d = self.ema_decay * (1 - F.exp(-self.updates / 2000))
+        # update trainable parameters
+        success = self.hyper_map(F.partial(_ema_op, d), self.ema_weight, self.weights_all)
+        self.updates = F.depend(self.updates, success)
+        return self.updates
+
+    def construct(self, *inputs):
+        """construct"""
+        weights = self.weights
+        loss = self.network(*inputs)
+        scaling_sens = self.scale_sense
+
+        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
+
+        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
+        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
+        # apply grad reducer on grads
+        grads = self.grad_reducer(grads)
+        # get the overflow buffer
+        cond = self.get_overflow_status(status, grads)
+        overflow = self.process_loss_scale(cond)
+        # if there is no overflow, do optimize
+        if not overflow:
+            loss = F.depend(loss, self.optimizer(grads))
+            if self.use_ema:
+                self.ema_update()
+        return loss
diff --git a/tests/modules/non_cpu/test_utils.py b/tests/modules/non_cpu/test_utils.py
new file mode 100644
index 00000000..76763bb2
--- /dev/null
+++ b/tests/modules/non_cpu/test_utils.py
@@ -0,0 +1,73 @@
+"""Test utils"""
+import os
+import sys
+sys.path.append('.')
+import pytest
+import numpy as np
+
+import mindspore as ms
+from mindspore import Tensor, nn, ops
+from mindspore.common.initializer import Normal
+from mindspore.nn import WithLossCell
+
+from mindcv.loss import create_loss
+from mindcv.optim import create_optimizer
+from mindcv.utils import TrainOneStepWithEMA
+
+ms.set_seed(1)
+np.random.seed(1)
+
+class SimpleCNN(nn.Cell):
+    def __init__(self, num_classes=10, in_channels=1, include_top=True):
+        super(SimpleCNN, self).__init__()
+        self.include_top = include_top
+        self.conv1 = nn.Conv2d(in_channels, 6, 5, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+
+        if self.include_top:
+            self.flatten = nn.Flatten()
+            self.fc = nn.Dense(16 * 5 * 5, num_classes, weight_init=Normal(0.02))
+
+    def construct(self, x):
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        x = self.conv2(x)
+        x = self.relu(x)
+        x = self.max_pool2d(x)
+        ret = x
+        if self.include_top:
+            x_flatten = self.flatten(x)
+            x = self.fc(x_flatten)
+            ret = x
+        return ret
+
+
+@pytest.mark.parametrize('use_ema', [True, False])
+@pytest.mark.parametrize('ema_decay', [0.9997, 0.5])
+def test_ema(use_ema, ema_decay):
+    network = SimpleCNN(in_channels=1, num_classes=10)
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    net_opt = create_optimizer(network.trainable_params(), 'adam', lr=0.001, weight_decay=1e-7)
+
+    bs = 8
+    input_data = Tensor(np.ones([bs, 1, 32, 32]).astype(np.float32) * 0.01)
+    label = Tensor(np.ones([bs]).astype(np.int32))
+
+    net_with_loss = WithLossCell(network, net_loss)
+    loss_scale_manager = Tensor(1, ms.float32)
+    train_network = TrainOneStepWithEMA(net_with_loss, net_opt, scale_sense=loss_scale_manager,
+                                        use_ema=use_ema, ema_decay=ema_decay)
+
+    train_network.set_train()
+
+    begin_loss = train_network(input_data, label)
+    for i in range(10):
+        cur_loss = train_network(input_data, label)
+    print(f"{net_opt}, begin loss: {begin_loss}, end loss: {cur_loss}")
+
+    # check output correctness
+    assert cur_loss < begin_loss, 'Loss does NOT decrease'
diff --git a/tests/modules/test_utils.py b/tests/modules/test_utils.py
index db1a177b..b9a5a72b 100644
--- a/tests/modules/test_utils.py
+++ b/tests/modules/test_utils.py
@@ -81,4 +81,4 @@ def test_checkpoint_manager(mode, ckpt_save_policy):
         save_path = os.path.join('./' + f'network_{t + 1}.ckpt')
         ckpoint_filelist = manager.save_ckpoint(network, num_ckpt=2, metric=acc, save_path=save_path)
-        assert len(ckpoint_filelist) == 2
\ No newline at end of file
+        assert len(ckpoint_filelist) == 2, "num of checkpoints is NOT correct"
diff --git a/tests/tasks/non_cpu/test_train_val_imagenet_subset.py b/tests/tasks/non_cpu/test_train_val_imagenet_subset.py
new file mode 100644
index 00000000..3bdb1804
--- /dev/null
+++ b/tests/tasks/non_cpu/test_train_val_imagenet_subset.py
@@ -0,0 +1,67 @@
+'''
+Test train and validate pipelines.
+For training, both graph mode and pynative mode with ms_function will be tested.
+'''
+import sys
+sys.path.append('.')
+
+import subprocess
+import os
+import pytest
+from mindcv.utils.download import DownLoad
+
+check_acc = True
+
+
+@pytest.mark.parametrize('use_ema', [True, False])
+@pytest.mark.parametrize('val_while_train', [True, False])
+def test_train_ema(use_ema, val_while_train, model='resnet18'):
+    ''' train on an ImageNet-style subset dataset '''
+    # prepare data
+    data_dir = 'data/Canidae'
+    num_classes = 2
+    dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
+    if not os.path.exists(data_dir):
+        DownLoad().download_and_extract_archive(dataset_url, './')
+
+    # ---------------- test running train.py using the toy data ---------
+    dataset = 'imagenet'
+    ckpt_dir = './tests/ckpt_tmp'
+    num_samples = 160
+    num_epochs = 5
+    batch_size = 20
+    if os.path.exists(ckpt_dir):
+        os.system(f'rm {ckpt_dir} -rf')
+    if os.path.exists(data_dir):
+        download_str = f'--data_dir {data_dir}'
+    else:
+        download_str = '--download'
+    train_file = 'train.py'
+
+    cmd = f'python {train_file} --dataset={dataset} --num_classes={num_classes} --model={model} ' \
+          f'--epoch_size={num_epochs} --ckpt_save_interval=2 --lr=0.0001 --num_samples={num_samples} ' \
+          f'--loss=CE --weight_decay=1e-6 --ckpt_save_dir={ckpt_dir} {download_str} --train_split=train ' \
+          f'--batch_size={batch_size} --pretrained --num_parallel_workers=2 --val_while_train={val_while_train} ' \
+          f'--val_split=val --val_interval=1 --use_ema'
+
+    print(f'Running command: \n{cmd}')
+    ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
+    assert ret == 0, 'Training fails'
+
+    # --------- Test running validate.py using the trained model ------------- #
+    # begin_ckpt = os.path.join(ckpt_dir, f'{model}-1_1.ckpt')
+    end_ckpt = os.path.join(ckpt_dir, f'{model}-{num_epochs}_{num_samples // batch_size}.ckpt')
+    cmd = f"python validate.py --model={model} --dataset={dataset} --val_split=val --data_dir={data_dir} --num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2"
+    # ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
+    print(f'Running command: \n{cmd}')
+    p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE)
+    out, err = p.communicate()
+    # assert ret==0, 'Validation fails'
+    print(out)
+
+    if check_acc:
+        res = out.decode()
+        idx = res.find('Accuracy')
+        acc = res[idx:].split(',')[0].split(':')[1]
+        print('Val acc: ', acc)
+        assert float(acc) > 0.5, 'Acc is too low'
diff --git a/train.py b/train.py
index cf8d938e..c73716f7 100644
--- a/train.py
+++ b/train.py
@@ -12,7 +12,7 @@
 from mindcv.loss import create_loss
 from mindcv.optim import create_optimizer
 from mindcv.scheduler import create_scheduler
-from mindcv.utils import StateMonitor, Allreduce
+from mindcv.utils import StateMonitor, Allreduce, TrainOneStepWithEMA
 from config import parse_args

 ms.set_seed(1)
@@ -28,7 +28,6 @@ def train(args):
     ''' main train function'''
     ms.set_context(mode=args.mode)
-
     if args.distribute:
         init()
         device_num = get_group_size()
@@ -143,7 +142,8 @@ def train(args):
                            drop_rate=args.drop_rate,
                            drop_path_rate=args.drop_path_rate,
                            pretrained=args.pretrained,
-                           checkpoint_path=args.ckpt_path)
+                           checkpoint_path=args.ckpt_path,
+                           use_ema=args.use_ema)

     num_params = sum([param.size for param in network.get_parameters()])

@@ -174,16 +174,27 @@ def train(args):
     # create optimizer
     #TODO: consistent naming opt, name, dataset_name
-    optimizer = create_optimizer(network.trainable_params(),
-                                 opt=args.opt,
-                                 lr=lr_scheduler,
-                                 weight_decay=args.weight_decay,
-                                 momentum=args.momentum,
-                                 nesterov=args.use_nesterov,
-                                 filter_bias_and_bn=args.filter_bias_and_bn,
-                                 loss_scale=args.loss_scale,
-                                 checkpoint_path=opt_ckpt_path,
-                                 eps=args.eps)
+    if args.use_ema:
+        optimizer = create_optimizer(network.trainable_params(),
+                                     opt=args.opt,
+                                     lr=lr_scheduler,
+                                     weight_decay=args.weight_decay,
+                                     momentum=args.momentum,
+                                     nesterov=args.use_nesterov,
+                                     filter_bias_and_bn=args.filter_bias_and_bn,
+                                     checkpoint_path=opt_ckpt_path,
+                                     eps=args.eps)
+    else:
+        optimizer = create_optimizer(network.trainable_params(),
+                                     opt=args.opt,
+                                     lr=lr_scheduler,
+                                     weight_decay=args.weight_decay,
+                                     momentum=args.momentum,
+                                     nesterov=args.use_nesterov,
+                                     filter_bias_and_bn=args.filter_bias_and_bn,
+                                     loss_scale=args.loss_scale,
+                                     checkpoint_path=opt_ckpt_path,
+                                     eps=args.eps)

     # Define eval metrics.
     if num_classes >= 5:
@@ -193,12 +204,27 @@ def train(args):
         eval_metrics = {'Top_1_Accuracy': nn.Top1CategoricalAccuracy()}

     # init model
-    if args.loss_scale > 1.0:
-        loss_scale_manager = FixedLossScaleManager(loss_scale=args.loss_scale, drop_overflow_update=False)
+    if args.use_ema:
+        net_with_loss = nn.WithLossCell(network, loss)
+
+        if args.dynamic_loss_scale:
+            loss_scale_manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=args.loss_scale, scale_factor=2,
+                                                               scale_window=1000)
+        else:
+            loss_scale_manager = nn.FixedLossScaleUpdateCell(loss_scale_value=args.loss_scale)
+        ms.amp.auto_mixed_precision(net_with_loss, amp_level=args.amp_level)
+        net_with_loss = TrainOneStepWithEMA(net_with_loss, optimizer, scale_sense=loss_scale_manager,
+                                            use_ema=args.use_ema, ema_decay=args.ema_decay)
+        eval_network = nn.WithEvalCell(network, loss, args.amp_level in ["O2", "O3", "auto"])
+        model = Model(net_with_loss, eval_network=eval_network, metrics=eval_metrics, eval_indexes=[0, 1, 2])
+    else:
+        if args.dynamic_loss_scale:
+            loss_scale_manager = ms.amp.DynamicLossScaleManager(init_loss_scale=args.loss_scale, scale_factor=2,
+                                                                scale_window=1000)
+        else:
+            loss_scale_manager = FixedLossScaleManager(loss_scale=args.loss_scale, drop_overflow_update=False)
         model = Model(network, loss_fn=loss, optimizer=optimizer, metrics=eval_metrics,
                       amp_level=args.amp_level, loss_scale_manager=loss_scale_manager)
-    else:
-        model = Model(network, loss_fn=loss, optimizer=optimizer, metrics=eval_metrics, amp_level=args.amp_level)

     # callback
     # save checkpoint, summary training loss
@@ -216,8 +242,7 @@ def train(args):
     state_cb = StateMonitor(model, summary_dir=summary_dir,
                             dataset_val=loader_eval,
                             val_interval=args.val_interval,
-                            metric_name=list(eval_metrics.keys())
-,
+                            metric_name=list(eval_metrics.keys()),
                             ckpt_dir=args.ckpt_save_dir,
                             ckpt_save_interval=args.ckpt_save_interval,
                             best_ckpt_name=args.model + '_best.ckpt',
diff --git a/validate.py b/validate.py
index 6dc848f2..200b1c12 100644
--- a/validate.py
+++ b/validate.py
@@ -55,7 +55,8 @@ def validate(args):
                            drop_rate=args.drop_rate,
                            drop_path_rate=args.drop_path_rate,
                            pretrained=args.pretrained,
-                           checkpoint_path=args.ckpt_path)
+                           checkpoint_path=args.ckpt_path,
+                           use_ema=args.use_ema)
     network.set_train(False)

     # create loss
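A minimal usage sketch of the flags introduced here, assuming the existing train.py CLI (the model, data_dir and loss_scale values are only illustrative):

    python train.py --model=resnet18 --dataset=imagenet --data_dir=./imagenet --use_ema --ema_decay=0.9999 --dynamic_loss_scale=True --loss_scale=1024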