Skip to content

Commit 8a3d82f

Browse files
committed
fix lzx1413#6
1 parent e4d4777 commit 8a3d82f

File tree

2 files changed

+35
-23
lines changed

2 files changed

+35
-23
lines changed

models/FSSD_vgg.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,19 @@ class FSSD(nn.Module):
3434
2) conv2d for localization predictions
3535
3) associated priorbox layer to produce default bounding
3636
boxes specific to the layer's feature map size.
37-
See: https://arxiv.org/pdf/1512.02325.pdf for more details.
37+
See: https://arxiv.org/pdf/1712.00960.pdf or more details.
3838
3939
Args:
40-
phase: (string) Can be "test" or "train"
4140
base: VGG16 layers for input, size of either 300 or 500
4241
extras: extra layers that feed to multibox loc and conf layers
4342
head: "multibox head" consists of loc and conf conv layers
4443
"""
4544

46-
def __init__(self,base, extras,ft_module,pyramid_ext, head, num_classes):
45+
def __init__(self,base, extras,ft_module,pyramid_ext, head, num_classes,size):
4746
super(FSSD, self).__init__()
4847
self.num_classes = num_classes
4948
# TODO: implement __call__ in PriorBox
50-
self.size = 300
49+
self.size = size
5150

5251
# SSD network
5352
self.base = nn.ModuleList(base)
@@ -157,25 +156,33 @@ def add_extras(cfg, i, batch_norm=False):
157156
in_channels = v
158157
return layers
159158

160-
def feature_transform_module(vgg, extral):
159+
def feature_transform_module(vgg, extral,size):
160+
if size == 300:
161+
up_size = 38
162+
elif size == 512:
163+
up_size = 64
164+
161165
layers = []
162166
#conv4_3
163167
layers += [BasicConv(vgg[24].out_channels,256,kernel_size=1,padding=0)]
164168
#fc_7
165-
layers += [BasicConv(vgg[-2].out_channels,256,kernel_size=1,padding=0,up_size=38)]
166-
layers += [BasicConv(extral[-1].out_channels,256,kernel_size=1,padding=0,up_size=38)]
169+
layers += [BasicConv(vgg[-2].out_channels,256,kernel_size=1,padding=0,up_size=up_size)]
170+
layers += [BasicConv(extral[-1].out_channels,256,kernel_size=1,padding=0,up_size=up_size)]
167171
return vgg,extral,layers
168172

169-
def pyramid_feature_extractor():
170-
layers = [BasicConv(256*3,512,kernel_size=3,stride=1,padding=1),BasicConv(512,512,kernel_size=3,stride=2,padding=1), \
173+
def pyramid_feature_extractor(size):
174+
if size == 300:
175+
layers = [BasicConv(256*3,512,kernel_size=3,stride=1,padding=1),BasicConv(512,512,kernel_size=3,stride=2,padding=1), \
171176
BasicConv(512,256,kernel_size=3,stride=2,padding=1),BasicConv(256,256,kernel_size=3,stride=2,padding=1), \
172177
BasicConv(256,256,kernel_size=3,stride=1,padding=0),BasicConv(256,256,kernel_size=3,stride=1,padding=0)]
178+
elif size == 512:
179+
layers = [BasicConv(256*3,512,kernel_size=3,stride=1,padding=1),BasicConv(512,512,kernel_size=3,stride=2,padding=1), \
180+
BasicConv(512,256,kernel_size=3,stride=2,padding=1),BasicConv(256,256,kernel_size=3,stride=2,padding=1), \
181+
BasicConv(256,256,kernel_size=3,stride=2,padding=1),BasicConv(256,256,kernel_size=3,stride=2,padding=1),\
182+
BasicConv(256,256,kernel_size=4,padding=1,stride=1)]
173183
return layers
174184

175185

176-
177-
178-
179186
def multibox(fea_channels, cfg, num_classes):
180187
loc_layers = []
181188
conf_layers = []
@@ -188,18 +195,20 @@ def multibox(fea_channels, cfg, num_classes):
188195

189196
extras = {
190197
'300': [256, 512, 128, 'S', 256],
191-
'512': [256, 'S',512,],
198+
'512': [256, 512, 128, 'S', 256],
192199
}
193200
mbox = {
194201
'300': [6, 6, 6, 6, 4, 4], # number of boxes per feature map location
195-
'512': [6,6,6,6,6,4,4],
202+
'512': [6, 6, 6, 6, 6, 4, 4],
196203
}
197-
fea_channels = [512,512,256,256,256,256]
204+
fea_channels = {
205+
'300':[512,512,256,256,256,256],
206+
'512':[512,512,256,256,256,256,256]}
198207

199208
def build_net(size=300, num_classes=21):
200209
if size != 300 and size != 512:
201-
print("Error: Sorry only SSD300 and SSD512 is supported currently!")
210+
print("Error: Sorry only FSSD300 and FSSD512 is supported currently!")
202211
return
203212

204-
return FSSD(*feature_transform_module(vgg(vgg_base[str(size)], 3), add_extras(extras[str(size)], 1024)), pyramid_ext=pyramid_feature_extractor(),
205-
head = multibox(fea_channels,mbox[str(size)],num_classes), num_classes=num_classes)
213+
return FSSD(*feature_transform_module(vgg(vgg_base[str(size)], 3), add_extras(extras[str(size)], 1024),size=size), pyramid_ext=pyramid_feature_extractor(size),
214+
head = multibox(fea_channels[str(size)],mbox[str(size)],num_classes), num_classes=num_classes,size = size)

models/SSD_vgg.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def __init__(self,base, extras, head, num_classes):
3232
# SSD network
3333
self.base = nn.ModuleList(base)
3434
# Layer learns to scale the l2 normalized features from conv4_3
35-
self.L2Norm = L2Norm(512, 20)
3635
self.extras = nn.ModuleList(extras)
36+
self.L2Norm = L2Norm(512,20)
3737

3838
self.loc = nn.ModuleList(head[0])
3939
self.conf = nn.ModuleList(head[1])
@@ -112,7 +112,7 @@ def load_weights(self, base_file):
112112

113113

114114

115-
def add_extras(cfg, i, batch_norm=False):
115+
def add_extras(cfg, i, batch_norm=False,size = 300):
116116
# Extra layers added to VGG for feature scaling
117117
layers = []
118118
in_channels = i
@@ -126,6 +126,9 @@ def add_extras(cfg, i, batch_norm=False):
126126
layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
127127
flag = not flag
128128
in_channels = v
129+
if size == 512:
130+
layers.append(nn.Conv2d(in_channels,128,kernel_size=1,stride=1))
131+
layers.append(nn.Conv2d(128,256,kernel_size=4,stride=1,padding=1))
129132
return layers
130133

131134

@@ -148,11 +151,11 @@ def multibox(vgg, extra_layers, cfg, num_classes):
148151

149152
extras = {
150153
'300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
151-
'512': [256, 'S',512,],
154+
'512': [256, 'S',512,128,'S',256,128,'S',256,128,'S',256],
152155
}
153156
mbox = {
154157
'300': [6, 6, 6, 6, 4, 4], # number of boxes per feature map location
155-
'512': [6,6,6,6,6,4,4],
158+
'512': [6, 6, 6, 6, 6, 4, 4],
156159
}
157160

158161

@@ -162,5 +165,5 @@ def build_net(size=300, num_classes=21):
162165
return
163166

164167
return SSD(*multibox(vgg(vgg_base[str(size)], 3),
165-
add_extras(extras[str(size)], 1024),
168+
add_extras(extras[str(size)], 1024,size = size),
166169
mbox[str(size)], num_classes), num_classes=num_classes)

0 commit comments

Comments
 (0)