add train_ema #256

Merged: 2 commits, merged on Dec 8, 2022
5 changes: 5 additions & 0 deletions config.py
@@ -110,6 +110,9 @@ def create_parser():
                       help='probability of applying cutmix and/or mixup (default=0.)')
    group.add_argument('--mixup', type=float, default=0.,
                       help='Hyperparameter of beta distribution of mixup. Recommended value is 0.2 for ImageNet. (default=0.)')
    group.add_argument('--use_ema', type=str2bool, nargs='?', const=True, default=False,
                       help='Whether to train with an exponential moving average (EMA) of model weights (default=False)')
    group.add_argument('--ema_decay', type=float, default=0.9999,
                       help='EMA decay rate (default=0.9999)')

    # Model parameters
    group = parser.add_argument_group('Model parameters')
@@ -150,6 +153,8 @@ def create_parser():
    # group.add_argument('--loss_scaler', type=str, default='static', help='Loss scaler, static or dynamic (default=static)')
    group.add_argument('--loss_scale', type=float, default=1.0,
                       help='Loss scale (default=1.0)')
    group.add_argument('--dynamic_loss_scale', type=str2bool, nargs='?', const=True, default=False,
                       help='Whether to use dynamic loss scale (default=False)')
    group.add_argument('--use_nesterov', type=str2bool, nargs='?', const=True, default=False,
                       help='Enables the Nesterov momentum (default=False)')
    group.add_argument('--filter_bias_and_bn', type=str2bool, nargs='?', const=True, default=True,
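The new flags slot into the existing argparse setup. A minimal sketch of how they would be consumed, assuming create_parser() returns a standard argparse.ArgumentParser importable from the top-level config module and that no other argument is required; the values below are illustrative only:

from config import create_parser

parser = create_parser()
# Parse the new EMA / loss-scale flags; str2bool converts the string values to booleans.
args = parser.parse_args(['--use_ema', 'True', '--ema_decay', '0.9995', '--dynamic_loss_scale', 'True'])
print(args.use_ema, args.ema_decay, args.dynamic_loss_scale)  # True 0.9995 True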
19 changes: 17 additions & 2 deletions mindcv/models/model_factory.py
@@ -11,6 +11,7 @@ def create_model(
        pretrained=False,
        in_channels: int = 3,
        checkpoint_path: str = '',
        use_ema: bool = False,
        **kwargs):
    r"""Creates model by name.

@@ -35,7 +36,21 @@
    model = create_fn(**model_args, **kwargs)

    if os.path.exists(checkpoint_path):
        param_dict = load_checkpoint(checkpoint_path)
        load_param_into_net(model, param_dict)

        checkpoint_param = load_checkpoint(checkpoint_path)
        ema_param_dict = dict()
        for param in checkpoint_param:
            if param.startswith("ema"):
                new_name = param.split("ema.")[1]
                ema_data = checkpoint_param[param]
                ema_data.name = new_name
                ema_param_dict[new_name] = ema_data

        if ema_param_dict and use_ema:
            load_param_into_net(model, ema_param_dict)
        elif not ema_param_dict and use_ema:
            raise ValueError('checkpoint_param does not contain EMA parameters, please set use_ema to False.')
        else:
            load_param_into_net(model, checkpoint_param)

    return model
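When use_ema is set and the checkpoint holds "ema."-prefixed parameters, those EMA weights are loaded into the model in place of the raw training weights. A small usage sketch, assuming create_model is re-exported from mindcv.models like the other factory functions and that the leading model-name / num_classes arguments come from the unchanged part of the signature; the model, class count, and path are placeholders:

from mindcv.models import create_model

# Load the EMA copy of the weights saved during EMA training (placeholder names/paths).
model = create_model(
    'resnet18',                                       # placeholder architecture
    num_classes=2,                                    # placeholder head size
    checkpoint_path='./ckpt_tmp/resnet18-5_8.ckpt',   # hypothetical checkpoint file
    use_ema=True,                                     # pick the "ema."-prefixed parameters
)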
3 changes: 2 additions & 1 deletion mindcv/utils/__init__.py
@@ -3,4 +3,5 @@
from .download import *
from .callbacks import *
from .checkpoint_manager import *
from .reduce_manager import *
from .reduce_manager import *
from .ema import *
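With ema re-exported here, the new train-step cell can be imported directly from the utils package, exactly as the new test below does:

from mindcv.utils import TrainOneStepWithEMA  # made available by the added `from .ema import *`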
23 changes: 15 additions & 8 deletions mindcv/utils/callbacks.py
@@ -97,19 +97,21 @@ def apply_eval(self):
    def on_train_step_end(self, run_context):
        cb_params = run_context.original_args()
        num_batches = cb_params.batch_num
        # global_step = cb_params.optimizer.global_step.asnumpy()[0]
        cur_epoch = cb_params.cur_epoch_num + self.last_epoch - 1  # (global_step-1) // num_batches
        # cur_step_in_epoch = (global_step - 1) % cb_params.batch_num
        cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num)

        if cb_params.optimizer is not None:
            optimizer = cb_params.optimizer
        else:
            optimizer = cb_params.train_network.network.optimizer

        if (cur_step_in_epoch + 1) % self.log_interval == 0 or \
                (cur_step_in_epoch + 1) >= num_batches or cur_step_in_epoch == 0:
            step = cb_params.optimizer.global_step
            if cb_params.optimizer.dynamic_lr:
                cur_lr = cb_params.optimizer.learning_rate(step-1)[0].asnumpy()
            step = optimizer.global_step
            if optimizer.dynamic_lr:
                cur_lr = optimizer.learning_rate(step-1)[0].asnumpy()
            else:
                cur_lr = cb_params.optimizer.learning_rate.asnumpy()
                cur_lr = optimizer.learning_rate.asnumpy()
            loss = self._get_loss(cb_params)

            print(f"Epoch: {cur_epoch+1}, "
@@ -123,8 +125,13 @@ def on_train_epoch_end(self, run_context):
        save the best ckpt file with highest validation accuracy.
        """
        cb_params = run_context.original_args()
        if cb_params.optimizer is not None:
            optimizer = cb_params.optimizer
        else:
            optimizer = cb_params.train_network.network.optimizer

        # the global step may be larger than batch_num * epoch due to graph-mode async execution
        global_step = cb_params.optimizer.global_step.asnumpy()[0]
        global_step = optimizer.global_step.asnumpy()[0]
        cur_epoch = cb_params.cur_epoch_num + self.last_epoch
        cur_step_in_epoch = cb_params.batch_num  # (global_step - 1) % cb_params.batch_num

@@ -170,7 +177,7 @@ def on_train_epoch_end(self, run_context):

        # save optim for resume
        optim_save_path = os.path.join(self.ckpt_dir, f'optim_{self.model_name}.ckpt')
        ms.save_checkpoint(cb_params.optimizer, optim_save_path, async_save=True)
        ms.save_checkpoint(optimizer, optim_save_path, async_save=True)

        cur_ckpoint_file = self.model_name + "-" + str(cur_epoch) + "_" \
                           + str(cur_step_in_epoch) + ".ckpt"
76 changes: 76 additions & 0 deletions mindcv/utils/ema.py
@@ -0,0 +1,76 @@
"""ema define"""

import mindspore as ms
from mindspore import nn, Tensor, Parameter, ParameterTuple
from mindspore.common import RowTensor
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

_ema_op = C.MultitypeFuncGraph("grad_ema_op")
_grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")


@_ema_op.register("Tensor", "Tensor", "Tensor")
def _ema_weights(factor, ema_weight, weight):
return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))

@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
return RowTensor(grad.indices,
grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
grad.dense_shape)


class TrainOneStepWithEMA(nn.TrainOneStepWithLossScaleCell):
"""TrainOneStepWithEMA"""

def __init__(self, network, optimizer, scale_sense=1.0, use_ema=False, ema_decay=0.9999, updates=0):
super(TrainOneStepWithEMA, self).__init__(network, optimizer, scale_sense)
self.use_ema = use_ema
self.ema_decay = ema_decay
self.updates = Parameter(Tensor(updates, ms.float32))
if self.use_ema:
self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
self.ema_weight = self.weights_all.clone("ema", init='same')



def ema_update(self):
"""Update EMA parameters."""
self.updates += 1
d = self.ema_decay * (1 - F.exp(-self.updates / 2000))
# update trainable parameters
success = self.hyper_map(F.partial(_ema_op, d), self.ema_weight, self.weights_all)
self.updates = F.depend(self.updates, success)
return self.updates

def construct(self, *inputs):
"""construct"""
weights = self.weights
loss = self.network(*inputs)
scaling_sens = self.scale_sense

status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
# apply grad reducer on grads
grads = self.grad_reducer(grads)
# get the overflow buffer
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
# if there is no overflow, do optimize
if not overflow:
loss = F.depend(loss, self.optimizer(grads))
if self.use_ema:
self.ema_update()
return loss
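Note the warmup applied in ema_update(): the effective factor d starts near 0, so the EMA initially just tracks the raw weights, and only approaches ema_decay (long-horizon averaging) after a few thousand updates. A small numeric sketch of that schedule:

import math

# Reproduce the factor d = ema_decay * (1 - exp(-updates / 2000)) used above.
ema_decay = 0.9999
for updates in (1, 100, 2000, 10000):
    d = ema_decay * (1 - math.exp(-updates / 2000))
    print(updates, round(d, 4))
# prints roughly: 1 0.0005, 100 0.0488, 2000 0.6321, 10000 0.9932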
73 changes: 73 additions & 0 deletions tests/modules/non_cpu/test_utils.py
@@ -0,0 +1,73 @@
"""Test utils"""
import os
import sys
sys.path.append('.')
import pytest
import numpy as np

import mindspore as ms
from mindspore import Tensor, nn, ops
from mindspore.common.initializer import Normal
from mindspore.nn import WithLossCell

from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.utils import TrainOneStepWithEMA

ms.set_seed(1)
np.random.seed(1)

class SimpleCNN(nn.Cell):
def __init__(self, num_classes=10, in_channels=1, include_top=True):
super(SimpleCNN, self).__init__()
self.include_top = include_top
self.conv1 = nn.Conv2d(in_channels, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

if self.include_top:
self.flatten = nn.Flatten()
self.fc = nn.Dense(16 * 5 * 5, num_classes, weight_init=Normal(0.02))

def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
ret = x
if self.include_top:
x_flatten = self.flatten(x)
x = self.fc(x_flatten)
ret = x
return ret


@pytest.mark.parametrize('use_ema', [True, False])
@pytest.mark.parametrize('ema_decay', [0.9997, 0.5])
def test_ema(use_ema, ema_decay):
network = SimpleCNN(in_channels=1, num_classes=10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

net_opt = create_optimizer(network.trainable_params(), 'adam', lr=0.001, weight_decay=1e-7)

bs = 8
input_data = Tensor(np.ones([bs, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([bs]).astype(np.int32))

net_with_loss = WithLossCell(network, net_loss)
loss_scale_manager = Tensor(1, ms.float32)
train_network = TrainOneStepWithEMA(net_with_loss, net_opt, scale_sense=loss_scale_manager,
use_ema=use_ema, ema_decay=ema_decay)

train_network.set_train()

begin_loss = train_network(input_data, label)
for i in range(10):
cur_loss = train_network(input_data, label)
print(f"{net_opt}, begin loss: {begin_loss}, end loss: {cur_loss}")

# check output correctness
assert cur_loss < begin_loss, 'Loss does NOT decrease'
2 changes: 1 addition & 1 deletion tests/modules/test_utils.py
@@ -81,4 +81,4 @@ def test_checkpoint_manager(mode, ckpt_save_policy):
        save_path = os.path.join('./' + f'network_{t + 1}.ckpt')
        ckpoint_filelist = manager.save_ckpoint(network, num_ckpt=2, metric=acc, save_path=save_path)

    assert len(ckpoint_filelist) == 2
    assert len(ckpoint_filelist) == 2, "num of checkpoints is NOT correct"
67 changes: 67 additions & 0 deletions tests/tasks/non_cpu/test_train_val_imagenet_subset.py
@@ -0,0 +1,67 @@
'''
Test train and validate pipelines.
For training, both graph mode and pynative mode with ms_function will be tested.
'''
import sys
sys.path.append('.')

import subprocess
import os
import pytest
from mindcv.utils.download import DownLoad

check_acc = True


@pytest.mark.parametrize('use_ema', [True, False])
@pytest.mark.parametrize('val_while_train', [True, False])
def test_train_ema(use_ema, val_while_train, model='resnet18'):
    ''' train on an ImageNet-style subset dataset '''
    # prepare data
    data_dir = 'data/Canidae'
    num_classes = 2
    dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
    if not os.path.exists(data_dir):
        DownLoad().download_and_extract_archive(dataset_url, './')

    # ---------------- test running train.py using the toy data ---------
    dataset = 'imagenet'
    ckpt_dir = './tests/ckpt_tmp'
    num_samples = 160
    num_epochs = 5
    batch_size = 20
    if os.path.exists(ckpt_dir):
        os.system(f'rm {ckpt_dir} -rf')
    if os.path.exists(data_dir):
        download_str = f'--data_dir {data_dir}'
    else:
        download_str = '--download'
    train_file = 'train.py'

    cmd = f'python {train_file} --dataset={dataset} --num_classes={num_classes} --model={model} ' \
          f'--epoch_size={num_epochs} --ckpt_save_interval=2 --lr=0.0001 --num_samples={num_samples} ' \
          f'--loss=CE --weight_decay=1e-6 --ckpt_save_dir={ckpt_dir} {download_str} --train_split=train ' \
          f'--batch_size={batch_size} --pretrained --num_parallel_workers=2 --val_while_train={val_while_train} ' \
          f'--val_split=val --val_interval=1 --use_ema={use_ema}'

    print(f'Running command: \n{cmd}')
    ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
    assert ret == 0, 'Training fails'

    # --------- Test running validate.py using the trained model ------------- #
    # begin_ckpt = os.path.join(ckpt_dir, f'{model}-1_1.ckpt')
    end_ckpt = os.path.join(ckpt_dir, f'{model}-{num_epochs}_{num_samples // batch_size}.ckpt')
    cmd = f"python validate.py --model={model} --dataset={dataset} --val_split=val --data_dir={data_dir} --num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2"
    # ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
    print(f'Running command: \n{cmd}')
    p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE)
    out, err = p.communicate()
    # assert ret == 0, 'Validation fails'
    print(out)

    if check_acc:
        res = out.decode()
        idx = res.find('Accuracy')
        acc = res[idx:].split(',')[0].split(':')[1]
        print('Val acc: ', acc)
        assert float(acc) > 0.5, 'Acc is too low'