
Commit b5de354

add mnasnet; add save embedding only option

nttstar committed Jan 18, 2019
1 parent 129bd77 commit b5de354
Showing 3 changed files with 181 additions and 6 deletions.
20 changes: 15 additions & 5 deletions recognition/sample_config.py
@@ -6,14 +6,15 @@
config.bn_mom = 0.9
config.workspace = 256
config.emb_size = 512
+config.ckpt_embedding = True
config.net_se = 0
config.net_act = 'prelu'
config.net_unit = 3
config.net_input = 1
config.net_output = 'E'
config.net_multiplier = 1.0
config.val_targets = ['lfw', 'cfp_fp', 'agedb_30']
-config.ce_loss = False
+config.ce_loss = True
config.fc7_lr_mult = 1.0
config.fc7_wd_mult = 1.0
config.fc7_no_bias = False
@@ -37,24 +38,33 @@

network.y1 = edict()
network.y1.net_name = 'fmobilefacenet'
-network.y1.num_layers = 1
network.y1.emb_size = 128
network.y1.net_output = 'GDC'

network.m1 = edict()
network.m1.net_name = 'fmobilenet'
-network.m1.num_layers = 1
network.m1.emb_size = 256
network.m1.net_output = 'GAP'
network.m1.net_multiplier = 1.0

network.m05 = edict()
network.m05.net_name = 'fmobilenet'
-network.m05.num_layers = 1
network.m05.emb_size = 256
-network.m05.net_output = 'GAP'
+network.m05.net_output = 'GDC'
network.m05.net_multiplier = 0.5

+network.mnas05 = edict()
+network.mnas05.net_name = 'fmnasnet'
+network.mnas05.emb_size = 256
+network.mnas05.net_output = 'GDC'
+network.mnas05.net_multiplier = 0.5
+
+network.mnas025 = edict()
+network.mnas025.net_name = 'fmnasnet'
+network.mnas025.emb_size = 256
+network.mnas025.net_output = 'GDC'
+network.mnas025.net_multiplier = 0.25
+
# dataset settings
dataset = edict()

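For orientation (this helper is not part of the commit): train.py selects one of these network presets by key, copying the preset's fields over the global defaults, which is how net_multiplier and net_output above reach the fmnasnet symbol builder. A minimal sketch of that merge, assuming edict comes from the easydict package and using a hypothetical merge function:

from easydict import EasyDict as edict

config = edict()
config.emb_size = 512          # global defaults, as in the hunk above
config.net_output = 'E'
config.net_multiplier = 1.0

network = edict()
network.mnas05 = edict(
    net_name='fmnasnet', emb_size=256, net_output='GDC', net_multiplier=0.5)

def apply_network_preset(key):
    # Hypothetical stand-in for the repo's config merge: every field of the
    # chosen preset overrides the corresponding global default.
    for k, v in network[key].items():
        config[k] = v

apply_network_preset('mnas05')
print(config.emb_size, config.net_output, config.net_multiplier)  # 256 GDC 0.5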
155 changes: 155 additions & 0 deletions recognition/symbol/fmnasnet.py
@@ -0,0 +1,155 @@
import sys
import os
import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon as gluon
import mxnet.gluon.nn as nn
import mxnet.autograd as ag
import symbol_utils
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from config import config

def ConvBlock(channels, kernel_size, strides, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, kernel_size, strides=strides, padding=1, use_bias=False),
            nn.BatchNorm(scale=True),
            nn.Activation('relu')
        )
    return out

def Conv1x1(channels, is_linear=False, **kwargs):
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, 1, padding=0, use_bias=False),
            nn.BatchNorm(scale=True)
        )
        # is_linear skips the ReLU: the linear bottleneck used after projection.
        if not is_linear:
            out.add(nn.Activation('relu'))
    return out

def DWise(channels, strides, kernel_size=3, **kwargs):
    # groups=channels makes this a depthwise convolution.
    out = nn.HybridSequential(**kwargs)
    with out.name_scope():
        out.add(
            nn.Conv2D(channels, kernel_size, strides=strides, padding=kernel_size // 2,
                      groups=channels, use_bias=False),
            nn.BatchNorm(scale=True),
            nn.Activation('relu')
        )
    return out

class SepCONV(nn.HybridBlock):
    # Depthwise separable convolution: a depthwise conv, and (unless output is
    # None) a 1x1 pointwise conv to the requested channel count.
    def __init__(self, inp, output, kernel_size, depth_multiplier=1, with_bn=True, **kwargs):
        super(SepCONV, self).__init__(**kwargs)
        with self.name_scope():
            self.net = nn.HybridSequential()
            cn = int(inp * depth_multiplier)

            if output is None:
                self.net.add(
                    nn.Conv2D(in_channels=inp, channels=cn, groups=inp, kernel_size=kernel_size,
                              strides=(1, 1), padding=kernel_size // 2, use_bias=not with_bn)
                )
            else:
                self.net.add(
                    nn.Conv2D(in_channels=inp, channels=cn, groups=inp, kernel_size=kernel_size,
                              strides=(1, 1), padding=kernel_size // 2, use_bias=False),
                    nn.BatchNorm(),
                    nn.Activation('relu'),
                    nn.Conv2D(in_channels=cn, channels=output, kernel_size=(1, 1),
                              strides=(1, 1), use_bias=not with_bn)
                )

            self.with_bn = with_bn
            self.act = nn.Activation('relu')
            if with_bn:
                self.bn = nn.BatchNorm()

    def hybrid_forward(self, F, x):
        x = self.net(x)
        if self.with_bn:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x

class ExpandedConv(nn.HybridBlock):
    # MobileNetV2-style inverted residual: 1x1 expansion by factor t, a
    # depthwise conv, then a linear 1x1 projection.
    def __init__(self, inp, oup, t, strides, kernel=3, same_shape=True, **kwargs):
        super(ExpandedConv, self).__init__(**kwargs)

        self.same_shape = same_shape
        self.strides = strides
        with self.name_scope():
            self.bottleneck = nn.HybridSequential()
            self.bottleneck.add(
                Conv1x1(inp * t, prefix="expand_"),
                DWise(inp * t, self.strides, kernel, prefix="dwise_"),
                Conv1x1(oup, is_linear=True, prefix="linear_")
            )

    def hybrid_forward(self, F, x):
        out = self.bottleneck(x)
        # Residual connection only when spatial size and channel count are unchanged.
        if self.strides == 1 and self.same_shape:
            out = F.elemwise_add(out, x)
        return out

def ExpandedConvSequence(t, k, inp, oup, repeats, first_strides, **kwargs):
    seq = nn.HybridSequential(**kwargs)
    with seq.name_scope():
        seq.add(ExpandedConv(inp, oup, t, first_strides, k, same_shape=False))
        curr_inp = oup
        for i in range(1, repeats):
            seq.add(ExpandedConv(curr_inp, oup, t, 1))
            curr_inp = oup
    return seq
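# Worked example (illustrative, not in the original file): with m=0.5 the
# first stage2_ sequence gets inp=16, oup=int(24*0.5)=12, t=3, k=3, so its
# lead block expands 16 -> 48 channels, applies a stride-2 3x3 depthwise
# conv, and projects 48 -> 12; the two stride-1 repeats map 12 -> 36 -> 12
# and add the residual, since input and output shapes match.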

class MNasNet(nn.HybridBlock):
    def __init__(self, m=1.0, **kwargs):
        super(MNasNet, self).__init__(**kwargs)

        # m is the width multiplier (get_symbol passes config.net_multiplier);
        # every channel count below is scaled by it.
        self.first_oup = int(32 * m)
        #self.second_oup = int(16*m)
        self.second_oup = int(32 * m)
        self.interverted_residual_setting = [
            # t (expansion), c (channels), n (repeats), s (first stride), k (kernel), prefix
            [3, int(24 * m), 3, 2, 3, "stage2_"],    # -> 56x56
            [3, int(40 * m), 3, 2, 5, "stage3_"],    # -> 28x28
            [6, int(80 * m), 3, 2, 5, "stage4_1_"],  # -> 14x14
            [6, int(96 * m), 2, 1, 3, "stage4_2_"],  # -> 14x14
            [6, int(192 * m), 4, 2, 5, "stage5_1_"], # -> 7x7
            [6, int(320 * m), 1, 1, 3, "stage5_2_"], # -> 7x7
        ]
        self.last_channels = int(1024 * m)

        with self.name_scope():
            self.features = nn.HybridSequential()
            self.features.add(ConvBlock(self.first_oup, 3, 1, prefix="stage1_conv0_"))
            self.features.add(SepCONV(self.first_oup, self.second_oup, 3, prefix="stage1_sepconv0_"))
            inp = self.second_oup
            for i, (t, c, n, s, k, prefix) in enumerate(self.interverted_residual_setting):
                oup = c
                self.features.add(ExpandedConvSequence(t, k, inp, oup, n, s, prefix=prefix))
                inp = oup

            self.features.add(Conv1x1(self.last_channels, prefix="stage5_3_"))
            #self.features.add(nn.GlobalAvgPool2D())
            #self.features.add(nn.Flatten())
            #self.output = nn.Dense(num_classes)

    def hybrid_forward(self, F, x):
        x = self.features(x)
        #x = self.output(x)
        return x

    def num_output_channel(self):
        return self.last_channels

def get_symbol():
    net = MNasNet(config.net_multiplier)
    data = mx.sym.Variable(name='data')
    # Normalize uint8 pixels to roughly [-1, 1]: (x - 127.5) * (1/128).
    data = data - 127.5
    data = data * 0.0078125
    body = net(data)
    fc1 = symbol_utils.get_fc1(body, config.emb_size, config.net_output, input_channel=net.num_output_channel())
    return fc1
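A quick shape check for the new symbol (illustrative, not part of the commit; it assumes the config fields above are already populated, e.g. by the mnas05 preset):

import mxnet as mx
import fmnasnet  # recognition/symbol/fmnasnet.py, with config set up beforehand

sym = fmnasnet.get_symbol()
# One 112x112 RGB face crop; infer_shape propagates it through the graph.
arg_shapes, out_shapes, aux_shapes = sym.infer_shape(data=(1, 3, 112, 112))
print(out_shapes)  # expected: [(1, config.emb_size)], one embedding per image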

12 changes: 11 additions & 1 deletion recognition/train.py
@@ -23,6 +23,7 @@
import fresnet
import fmobilefacenet
import fmobilenet
+import fmnasnet


logger = logging.getLogger()
@@ -360,7 +361,16 @@ def _batch_callback(param):
if do_save:
  print('saving', msave)
  arg, aux = model.get_params()
-  mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
+  if config.ckpt_embedding:
+    all_layers = model.symbol.get_internals()
+    _sym = all_layers['fc1_output']
+    _arg = {}
+    for k in arg:
+      if not k.startswith('fc7'):
+        _arg[k] = arg[k]
+    mx.model.save_checkpoint(prefix, msave, _sym, _arg, aux)
+  else:
+    mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
print('[%d]Accuracy-Highest: %1.5f'%(mbatch, highest_acc[-1]))
if config.max_steps>0 and mbatch>config.max_steps:
  sys.exit(0)
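The effect of ckpt_embedding is that the saved checkpoint ends at the fc1 embedding output and carries no fc7 class weights, so it loads directly for feature extraction. A minimal loading sketch with the standard MXNet module API; the prefix 'model/mnas05' and epoch 1 are placeholders for whatever train.py wrote:

import mxnet as mx

sym, arg_params, aux_params = mx.model.load_checkpoint('model/mnas05', 1)
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
mod.bind(data_shapes=[('data', (1, 3, 112, 112))], for_training=False)
mod.set_params(arg_params, aux_params)

img = mx.nd.zeros((1, 3, 112, 112))  # stand-in for a preprocessed face crop
mod.forward(mx.io.DataBatch([img]), is_train=False)
embedding = mod.get_outputs()[0]     # shape (1, config.emb_size)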
