-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #256 from Songyuanwei/branch_6
Add EMA support for training. TODO: validate with EMA weights while training.
- Loading branch information
Showing
10 changed files
with
302 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""ema define""" | ||
|
||
import mindspore as ms | ||
from mindspore import nn, Tensor, Parameter, ParameterTuple | ||
from mindspore.common import RowTensor | ||
from mindspore.ops import composite as C | ||
from mindspore.ops import functional as F | ||
from mindspore.ops import operations as P | ||
|
||
# Hyper-map graph functions used by TrainOneStepWithEMA below.
_ema_op = C.MultitypeFuncGraph("grad_ema_op")  # per-parameter EMA update
_grad_scale = C.MultitypeFuncGraph("grad_scale")  # per-gradient loss-scale removal
reciprocal = P.Reciprocal()  # used to divide gradients by the loss scale
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")  # NOTE(review): declared but never registered/used in this file
|
||
|
||
@_ema_op.register("Tensor", "Tensor", "Tensor")
def _ema_weights(factor, ema_weight, weight):
    """In-place EMA step: ema_weight <- ema_weight * factor + weight * (1 - factor)."""
    return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))
|
||
@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    """Remove the loss scale from a dense gradient (grad / scale)."""
    inv_scale = reciprocal(scale)
    return grad * F.cast(inv_scale, F.dtype(grad))
|
||
|
||
@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_row_tensor(scale, grad):
    """Remove the loss scale from a sparse (RowTensor) gradient's values."""
    inv_scale = F.cast(reciprocal(scale), F.dtype(grad.values))
    scaled_values = grad.values * inv_scale
    return RowTensor(grad.indices, scaled_values, grad.dense_shape)
|
||
|
||
class TrainOneStepWithEMA(nn.TrainOneStepWithLossScaleCell):
    """Train-one-step cell with loss scaling and an optional EMA of the weights.

    Args:
        network: forward network (including the loss).
        optimizer: optimizer applied to the network parameters.
        scale_sense: loss-scale value or update cell, forwarded to the parent.
        use_ema: if True, keep an exponential-moving-average shadow copy of
            every network parameter, updated after each non-overflow step.
        ema_decay: upper bound of the EMA decay factor.
        updates: initial value of the step counter used by the decay warm-up.
    """

    def __init__(self, network, optimizer, scale_sense=1.0, use_ema=False, ema_decay=0.9999, updates=0):
        super(TrainOneStepWithEMA, self).__init__(network, optimizer, scale_sense)
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        # Counter is a Parameter so it can be incremented inside the graph.
        self.updates = Parameter(Tensor(updates, ms.float32))
        if self.use_ema:
            self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
            # Shadow copy of all parameters, initialized to their current values.
            self.ema_weight = self.weights_all.clone("ema", init='same')

    def ema_update(self):
        """Update EMA parameters."""
        self.updates += 1
        # Warm-up schedule: the effective decay ramps from 0 toward ema_decay
        # as updates grows (time constant 2000 steps), so early EMA values are
        # not dominated by the random initialization.
        d = self.ema_decay * (1 - F.exp(-self.updates / 2000))
        # update trainable parameters
        success = self.hyper_map(F.partial(_ema_op, d), self.ema_weight, self.weights_all)
        # depend() orders the EMA assigns before the counter is returned.
        self.updates = F.depend(self.updates, success)
        return self.updates

    def construct(self, *inputs):
        """One training step: forward, scaled backward, unscale, reduce,
        overflow check, optimizer step and (optionally) EMA update."""
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        # Scale the loss gradient seed by the loss-scale value.
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        # Undo the loss scaling on every gradient.
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        # get the overflow buffer
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
        # if there is no overflow, do optimize
        if not overflow:
            loss = F.depend(loss, self.optimizer(grads))
            if self.use_ema:
                self.ema_update()
        return loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Test utils""" | ||
import os | ||
import sys | ||
sys.path.append('.') | ||
import pytest | ||
import numpy as np | ||
|
||
import mindspore as ms | ||
from mindspore import Tensor, nn, ops | ||
from mindspore.common.initializer import Normal | ||
from mindspore.nn import WithLossCell | ||
|
||
from mindcv.loss import create_loss | ||
from mindcv.optim import create_optimizer | ||
from mindcv.utils import TrainOneStepWithEMA | ||
|
||
# Fix framework and numpy seeds so the loss trajectory is reproducible.
ms.set_seed(1)
np.random.seed(1)
|
||
class SimpleCNN(nn.Cell):
    """A small LeNet-style CNN used to exercise the training utilities.

    Returns logits of shape (N, num_classes) when include_top is True,
    otherwise the pooled feature maps from the second conv stage.
    """

    def __init__(self, num_classes=10, in_channels=1, include_top=True):
        super(SimpleCNN, self).__init__()
        self.include_top = include_top
        self.conv1 = nn.Conv2d(in_channels, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc = nn.Dense(16 * 5 * 5, num_classes, weight_init=Normal(0.02))

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages, then optional classifier head."""
        feat = self.max_pool2d(self.relu(self.conv1(x)))
        feat = self.max_pool2d(self.relu(self.conv2(feat)))
        if not self.include_top:
            return feat
        return self.fc(self.flatten(feat))
|
||
|
||
@pytest.mark.parametrize('use_ema', [True, False])
@pytest.mark.parametrize('ema_decay', [0.9997, 0.5])
def test_ema(use_ema, ema_decay):
    """Train a tiny CNN for a few steps with TrainOneStepWithEMA and check
    that the loss decreases, for both EMA and non-EMA configurations."""
    net = SimpleCNN(in_channels=1, num_classes=10)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    net_opt = create_optimizer(net.trainable_params(), 'adam', lr=0.001, weight_decay=1e-7)

    # Constant synthetic batch: 8 grayscale 32x32 images, all-ones labels.
    batch = 8
    images = Tensor(np.ones([batch, 1, 32, 32]).astype(np.float32) * 0.01)
    targets = Tensor(np.ones([batch]).astype(np.int32))

    loss_scale = Tensor(1, ms.float32)
    train_network = TrainOneStepWithEMA(WithLossCell(net, criterion), net_opt,
                                        scale_sense=loss_scale,
                                        use_ema=use_ema, ema_decay=ema_decay)

    train_network.set_train()

    begin_loss = train_network(images, targets)
    for _ in range(10):
        cur_loss = train_network(images, targets)
    print(f"{net_opt}, begin loss: {begin_loss}, end loss: {cur_loss}")

    # check output correctness
    assert cur_loss < begin_loss, 'Loss does NOT decrease'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
''' | ||
Test train and validate pipelines. | ||
For training, both graph mode and pynative mode with ms_function will be tested. | ||
''' | ||
import sys | ||
sys.path.append('.') | ||
|
||
import os
import shutil
import subprocess

import pytest

from mindcv.utils.download import DownLoad
|
||
# When True, parse validate.py's stdout and assert a minimum accuracy.
check_acc = True
|
||
|
||
@pytest.mark.parametrize('use_ema', [True, False])
@pytest.mark.parametrize('val_while_train', [True, False])
def test_train_ema(use_ema, val_while_train, model='resnet18'):
    """Train `model` on a toy ImageNet-subset dataset via train.py, then run
    validate.py on the final checkpoint and check the reported accuracy.

    Covers both EMA and non-EMA training, with and without in-train validation.
    """
    # prepare data
    data_dir = 'data/Canidae'
    num_classes = 2
    dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"
    if not os.path.exists(data_dir):
        DownLoad().download_and_extract_archive(dataset_url, './')

    # ---------------- test running train.py using the toy data ---------
    dataset = 'imagenet'
    ckpt_dir = './tests/ckpt_tmp'
    num_samples = 160
    num_epochs = 5
    batch_size = 20
    if os.path.exists(ckpt_dir):
        # Portable replacement for `os.system(f'rm {ckpt_dir} -rf')`.
        shutil.rmtree(ckpt_dir, ignore_errors=True)
    if os.path.exists(data_dir):
        download_str = f'--data_dir {data_dir}'
    else:
        download_str = '--download'
    train_file = 'train.py'

    # FIX: `--use_ema` was previously appended unconditionally, so the
    # use_ema=False parametrization still trained with EMA enabled.
    ema_str = '--use_ema' if use_ema else ''
    cmd = f'python {train_file} --dataset={dataset} --num_classes={num_classes} --model={model} ' \
          f'--epoch_size={num_epochs} --ckpt_save_interval=2 --lr=0.0001 --num_samples={num_samples} ' \
          f'--loss=CE --weight_decay=1e-6 --ckpt_save_dir={ckpt_dir} {download_str} --train_split=train ' \
          f'--batch_size={batch_size} --pretrained --num_parallel_workers=2 --val_while_train={val_while_train} ' \
          f'--val_split=val --val_interval=1 {ema_str}'

    print(f'Running command: \n{cmd}')
    ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
    assert ret == 0, 'Training fails'

    # --------- Test running validate.py using the trained model ------------- #
    # Checkpoint name follows the `{model}-{epoch}_{steps_per_epoch}.ckpt` pattern.
    end_ckpt = os.path.join(ckpt_dir, f'{model}-{num_epochs}_{num_samples // batch_size}.ckpt')
    cmd = f"python validate.py --model={model} --dataset={dataset} --val_split=val --data_dir={data_dir} --num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2"
    print(f'Running command: \n{cmd}')
    p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE)
    out, _ = p.communicate()
    # NOTE(review): validate.py's exit code is not checked; the accuracy
    # parsing below is the effective assertion for this stage.
    print(out)

    if check_acc:
        res = out.decode()
        idx = res.find('Accuracy')
        acc = res[idx:].split(',')[0].split(':')[1]
        print('Val acc: ', acc)
        assert float(acc) > 0.5, 'Acc is too low'
Oops, something went wrong.