Commit

add densenet
nttstar committed Jan 21, 2019
1 parent 067f648 commit c5805b8
Showing 6 changed files with 251 additions and 20 deletions.
12 changes: 11 additions & 1 deletion recognition/sample_config.py
@@ -41,6 +41,10 @@
network.r50v1.num_layers = 50
network.r50v1.net_unit = 1

network.d169 = edict()
network.d169.net_name = 'fdensenet'
network.d169.num_layers = 169
network.d169.per_batch_size = 64
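# (explanatory note, not part of the commit) net_name='fdensenet' is what the
# training script resolves to the new symbol file added below, and
# per_batch_size=64 overrides the global default so the larger model fits in
# GPU memory; assuming the repo's usual CLI, this entry is selected with
# train.py's --network d169 flag.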

network.y1 = edict()
network.y1.net_name = 'fmobilefacenet'
@@ -50,7 +54,7 @@
network.m1 = edict()
network.m1.net_name = 'fmobilenet'
network.m1.emb_size = 256
network.m1.net_output = 'GAP'
network.m1.net_output = 'GDC'
network.m1.net_multiplier = 1.0
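# (explanatory note, not part of the commit) 'GDC' swaps the previous global
# average pooling head ('GAP') for the global depthwise convolution head
# introduced in the MobileFaceNet paper, which its authors found better
# suited to face embeddings than plain averaging.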

network.m05 = edict()
@@ -59,6 +63,12 @@
network.m05.net_output = 'GDC'
network.m05.net_multiplier = 0.5

network.mnas = edict()
network.mnas.net_name = 'fmnasnet'
network.mnas.emb_size = 256
network.mnas.net_output = 'GDC'
network.mnas.net_multiplier = 1.0

network.mnas05 = edict()
network.mnas05.net_name = 'fmnasnet'
network.mnas05.emb_size = 256
145 changes: 145 additions & 0 deletions recognition/symbol/fdensenet.py
@@ -0,0 +1,145 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ
"""DenseNet, implemented in Gluon."""

import sys
import os
import mxnet as mx
import mxnet.ndarray as nd
import mxnet.gluon as gluon
import mxnet.gluon.nn as nn
import mxnet.autograd as ag
import symbol_utils
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from config import config

def Act():
if config.net_act=='prelu':
return nn.PReLU()
else:
return nn.Activation(config.net_act)
# Helpers
def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index):
out = nn.HybridSequential(prefix='stage%d_'%stage_index)
with out.name_scope():
for _ in range(num_layers):
out.add(_make_dense_layer(growth_rate, bn_size, dropout))
return out

def _make_dense_layer(growth_rate, bn_size, dropout):
new_features = nn.HybridSequential(prefix='')
new_features.add(nn.BatchNorm())
#new_features.add(nn.Activation('relu'))
new_features.add(Act())
new_features.add(nn.Conv2D(bn_size * growth_rate, kernel_size=1, use_bias=False))
new_features.add(nn.BatchNorm())
#new_features.add(nn.Activation('relu'))
new_features.add(Act())
new_features.add(nn.Conv2D(growth_rate, kernel_size=3, padding=1, use_bias=False))
if dropout:
new_features.add(nn.Dropout(dropout))

out = gluon.contrib.nn.HybridConcurrent(axis=1, prefix='')
out.add(gluon.contrib.nn.Identity())
out.add(new_features)

return out
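
# (explanatory note, not part of the commit) HybridConcurrent(axis=1) with an
# Identity branch is what makes the layer "dense": the layer's input is
# concatenated channel-wise with the growth_rate new feature maps, so the
# channel count grows by growth_rate per layer. A quick hand-check:
#   layer = _make_dense_layer(growth_rate=32, bn_size=4, dropout=0)
#   layer.initialize()
#   y = layer(mx.nd.zeros((1, 64, 56, 56)))  # y.shape -> (1, 96, 56, 56)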

def _make_transition(num_output_features):
out = nn.HybridSequential(prefix='')
out.add(nn.BatchNorm())
#out.add(nn.Activation('relu'))
out.add(Act())
out.add(nn.Conv2D(num_output_features, kernel_size=1, use_bias=False))
out.add(nn.AvgPool2D(pool_size=2, strides=2))
return out

# Net
class DenseNet(nn.HybridBlock):
r"""Densenet-BC model from the
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.
Parameters
----------
num_init_features : int
Number of filters to learn in the first convolution layer.
growth_rate : int
Number of filters to add each layer (`k` in the paper).
block_config : list of int
List of integers for numbers of layers in each pooling block.
    bn_size : int, default 4
        Multiplicative factor for the number of bottleneck features
        (i.e. bn_size * k features in each bottleneck layer).
    dropout : float, default 0
        Rate of dropout after each dense layer.
    classes : int, default 1000
        Number of classification classes (unused here: the classification
        head is removed, so the network ends at the final feature map).
"""
def __init__(self, num_init_features, growth_rate, block_config,
bn_size=4, dropout=0, classes=1000, **kwargs):

super(DenseNet, self).__init__(**kwargs)
with self.name_scope():
self.features = nn.HybridSequential(prefix='')
self.features.add(nn.Conv2D(num_init_features, kernel_size=3,
strides=1, padding=1, use_bias=False))
self.features.add(nn.BatchNorm())
self.features.add(nn.Activation('relu'))
self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))
# Add dense blocks
num_features = num_init_features
for i, num_layers in enumerate(block_config):
self.features.add(_make_dense_block(num_layers, bn_size, growth_rate, dropout, i+1))
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
self.features.add(_make_transition(num_features // 2))
num_features = num_features // 2
self.features.add(nn.BatchNorm())
self.features.add(nn.Activation('relu'))
#self.features.add(nn.AvgPool2D(pool_size=7))
#self.features.add(nn.Flatten())

#self.output = nn.Dense(classes)

def hybrid_forward(self, F, x):
x = self.features(x)
#x = self.output(x)
return x
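
# (explanatory note, not part of the commit) Compared with the stock Gluon
# DenseNet, the ImageNet stem (7x7 stride-2 conv) is replaced above by a
# 3x3 stride-1 conv, and the classifier head is dropped, leaving a feature
# extractor sized for 112x112 face crops.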


# Specification
densenet_spec = {121: (64, 32, [6, 12, 24, 16]),
161: (96, 48, [6, 12, 36, 24]),
169: (64, 32, [6, 12, 32, 32]),
201: (64, 32, [6, 12, 48, 32])}
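
# (hand-check, not part of the original file) Each block adds
# num_layers * growth_rate channels and each transition halves the count,
# so for the 169-layer spec:
#   64  + 6*32  = 256  -> transition -> 128
#   128 + 12*32 = 512  -> transition -> 256
#   256 + 32*32 = 1280 -> transition -> 640
#   640 + 32*32 = 1664 channels into the final BatchNorm/Activation and fc1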


# Constructor
def get_symbol():
num_layers = config.num_layers
num_init_features, growth_rate, block_config = densenet_spec[num_layers]
net = DenseNet(num_init_features, growth_rate, block_config)
data = mx.sym.Variable(name='data')
data = data-127.5
data = data*0.0078125
body = net(data)
fc1 = symbol_utils.get_fc1(body, config.emb_size, config.net_output)
return fc1
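
# (explanatory note, not part of the commit) The two lines above map uint8
# pixels from [0, 255] to roughly [-1, 1], since 0.0078125 == 1/128. A quick
# shape sanity check, assuming the standard 112x112 face-crop input:
#   sym = get_symbol()
#   _, out_shapes, _ = sym.infer_shape(data=(1, 3, 112, 112))
#   print(out_shapes)  # [(1, config.emb_size)]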

5 changes: 2 additions & 3 deletions recognition/symbol/fmnasnet.py
@@ -118,10 +118,9 @@ class MNasNet(nn.HybridBlock):
def __init__(self, m=1.0, **kwargs):
super(MNasNet, self).__init__(**kwargs)

m = config.net_multiplier
self.first_oup = int(32*m)
#self.second_oup = int(16*m)
self.second_oup = int(32*m)
self.second_oup = int(16*m)
#self.second_oup = int(32*m)
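# (explanatory note, not part of the commit) Each entry below reads as
# (expansion factor t, output channels c, block count n, first stride s,
# kernel size k, name prefix), matching the "t, c, n, s, k" comment.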
self.interverted_residual_setting = [
# t, c, n, s, k
[3, int(24*m), 3, 2, 3, "stage2_"], # -> 56x56
16 changes: 2 additions & 14 deletions recognition/symbol/fmobilefacenet.py
@@ -6,8 +6,6 @@
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from config import config

bn_mom = 0.9
#bn_mom = 0.9997

def Act(data, act_type, name):
#ignore param act_type, set it in this function
@@ -19,13 +17,13 @@ def Act(data, act_type, name):

def Conv(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, num_group=num_group, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=False,momentum=bn_mom)
bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=False,momentum=config.bn_mom)
act = Act(data=bn, act_type=config.net_act, name='%s%s_relu' %(name, suffix))
return act

def Linear(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, num_group=num_group, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix))
bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=False,momentum=bn_mom)
bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=False,momentum=config.bn_mom)
return bn

def ConvOnly(data, num_filter=1, kernel=(1, 1), stride=(1, 1), pad=(0, 0), num_group=1, name=None, suffix=''):
@@ -50,17 +48,7 @@ def Residual(data, num_block=1, num_out=1, kernel=(3, 3), stride=(1, 1), pad=(1,

def get_symbol():
num_classes = config.emb_size
bn_mom = config.bn_mom
workspace = config.workspace
print('in_network', config)
#kwargs = {'version_se' : config.net_se,
# 'version_input': config.net_input,
# 'version_output': config.net_output,
# 'version_unit': config.net_unit,
# 'version_act': config.net_act,
# 'bn_mom': bn_mom,
# 'workspace': workspace,
# }
fc_type = config.net_output
data = mx.symbol.Variable(name="data")
data = data-127.5
91 changes: 90 additions & 1 deletion recognition/symbol/symbol_utils.py
@@ -139,6 +139,94 @@ def residual_unit_v3(data, num_filter, stride, dim_match, name, **kwargs):
shortcut._set_attr(mirror_stage='True')
return bn3 + shortcut

def residual_unit_v1l(data, num_filter, stride, dim_match, name, bottle_neck):
"""Return ResNet Unit symbol for building ResNet
Parameters
----------
data : str
Input data
num_filter : int
Number of output channels
bnf : int
Bottle neck channels factor with regard to num_filter
stride : tuple
Stride used in convolution
dim_match : Boolean
True means channel number between input and output is the same, otherwise means differ
name : str
Base name of the operators
workspace : int
Workspace used in convolution operator
"""
workspace = config.workspace
bn_mom = config.bn_mom
memonger = False
use_se = config.net_se
act_type = config.net_act
#print('in unit1')
if bottle_neck:
conv1 = Conv(data=data, num_filter=int(num_filter*0.25), kernel=(1,1), stride=(1,1), pad=(0,0),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn1')
act1 = Act(data=bn1, act_type=act_type, name=name + '_relu1')
conv2 = Conv(data=act1, num_filter=int(num_filter*0.25), kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
bn2 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn2')
act2 = Act(data=bn2, act_type=act_type, name=name + '_relu2')
conv3 = Conv(data=act2, num_filter=num_filter, kernel=(1,1), stride=stride, pad=(0,0), no_bias=True,
workspace=workspace, name=name + '_conv3')
bn3 = mx.sym.BatchNorm(data=conv3, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_bn3')

if use_se:
#se begin
body = mx.sym.Pooling(data=bn3, global_pool=True, kernel=(7, 7), pool_type='avg', name=name+'_se_pool1')
body = Conv(data=body, num_filter=num_filter//16, kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_se_conv1", workspace=workspace)
body = Act(data=body, act_type=act_type, name=name+'_se_relu1')
body = Conv(data=body, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_se_conv2", workspace=workspace)
body = mx.symbol.Activation(data=body, act_type='sigmoid', name=name+"_se_sigmoid")
bn3 = mx.symbol.broadcast_mul(bn3, body)
#se end

if dim_match:
shortcut = data
else:
conv1sc = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_conv1sc')
shortcut = mx.sym.BatchNorm(data=conv1sc, fix_gamma=False, eps=2e-5, momentum=bn_mom, name=name + '_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return Act(data=bn3 + shortcut, act_type=act_type, name=name + '_relu3')
else:
conv1 = Conv(data=data, num_filter=num_filter, kernel=(3,3), stride=(1,1), pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv1')
bn1 = mx.sym.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn1')
act1 = Act(data=bn1, act_type=act_type, name=name + '_relu1')
conv2 = Conv(data=act1, num_filter=num_filter, kernel=(3,3), stride=stride, pad=(1,1),
no_bias=True, workspace=workspace, name=name + '_conv2')
bn2 = mx.sym.BatchNorm(data=conv2, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_bn2')
if use_se:
#se begin
body = mx.sym.Pooling(data=bn2, global_pool=True, kernel=(7, 7), pool_type='avg', name=name+'_se_pool1')
body = Conv(data=body, num_filter=num_filter//16, kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_se_conv1", workspace=workspace)
body = Act(data=body, act_type=act_type, name=name+'_se_relu1')
body = Conv(data=body, num_filter=num_filter, kernel=(1,1), stride=(1,1), pad=(0,0),
name=name+"_se_conv2", workspace=workspace)
body = mx.symbol.Activation(data=body, act_type='sigmoid', name=name+"_se_sigmoid")
bn2 = mx.symbol.broadcast_mul(bn2, body)
#se end

if dim_match:
shortcut = data
else:
conv1sc = Conv(data=data, num_filter=num_filter, kernel=(1,1), stride=stride, no_bias=True,
workspace=workspace, name=name+'_conv1sc')
shortcut = mx.sym.BatchNorm(data=conv1sc, fix_gamma=False, momentum=bn_mom, eps=2e-5, name=name + '_sc')
if memonger:
shortcut._set_attr(mirror_stage='True')
return Act(data=bn2 + shortcut, act_type=act_type, name=name + '_relu3')
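
# (explanatory note, not part of the commit) With dim_match=False the shortcut
# is projected through a strided 1x1 conv + BN so its shape matches the
# residual branch; get_head below calls the plain two-conv variant:
#   body = residual_unit_v1l(body, num_filter, (2, 2), False,
#                            name='head', bottle_neck=False)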

def get_head(data, version_input, num_filter):
bn_mom = config.bn_mom
@@ -160,7 +248,8 @@ def get_head(data, version_input, num_filter):
no_bias=True, name="conv0", workspace=workspace)
body = mx.sym.BatchNorm(data=body, fix_gamma=False, eps=2e-5, momentum=bn_mom, name='bn0')
body = Act(data=body, act_type=config.net_act, name='relu0')
body = residual_unit_v3(body, _num_filter, (2, 2), False, name='head', **kwargs)
#body = residual_unit_v3(body, _num_filter, (2, 2), False, name='head', **kwargs)
body = residual_unit_v1l(body, _num_filter, (2, 2), False, name='head', bottle_neck=False)
return body


2 changes: 1 addition & 1 deletion recognition/train.py
@@ -24,6 +24,7 @@
import fmobilefacenet
import fmobilenet
import fmnasnet
import fdensenet
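# (explanatory note, not part of the commit) importing fdensenet here makes
# the module visible to the script's symbol dispatch, which (in this repo's
# usual pattern) resolves config.net_name -- 'fdensenet' for the d169
# network -- via eval(config.net_name).get_symbol().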


logger = logging.getLogger()
@@ -394,7 +395,6 @@ def _batch_callback(param):
epoch_end_callback = epoch_cb )

def main():
#time.sleep(3600*6.5)
global args
args = parse_args()
train_net(args)