Skip to content

Commit b6e84d9

Browse files
committed
add FRFBSSD
1 parent 514a827 commit b6e84d9

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

models/base_models.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,72 @@ def vgg(cfg, i, batch_norm=False):
3535
'512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
3636
512, 512, 512],
3737
}
38+
class BasicConv(nn.Module):
39+
40+
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
41+
super(BasicConv, self).__init__()
42+
self.out_channels = out_planes
43+
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
44+
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
45+
self.relu = nn.ReLU(inplace=True) if relu else None
46+
47+
def forward(self, x):
48+
x = self.conv(x)
49+
if self.bn is not None:
50+
x = self.bn(x)
51+
if self.relu is not None:
52+
x = self.relu(x)
53+
return x
54+
class BasicRFB_a(nn.Module):
55+
56+
def __init__(self, in_planes, out_planes, stride=1, scale = 0.1):
57+
super(BasicRFB_a, self).__init__()
58+
self.scale = scale
59+
self.out_channels = out_planes
60+
inter_planes = in_planes //4
61+
62+
63+
self.branch0 = nn.Sequential(
64+
BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
65+
BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1,relu=False)
66+
)
67+
self.branch1 = nn.Sequential(
68+
BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
69+
BasicConv(inter_planes, inter_planes, kernel_size=(3,1), stride=1, padding=(1,0)),
70+
BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False)
71+
)
72+
self.branch2 = nn.Sequential(
73+
BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
74+
BasicConv(inter_planes, inter_planes, kernel_size=(1,3), stride=stride, padding=(0,1)),
75+
BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False)
76+
)
77+
'''
78+
self.branch3 = nn.Sequential(
79+
BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
80+
BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1),
81+
BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False)
82+
)
83+
'''
84+
self.branch3 = nn.Sequential(
85+
BasicConv(in_planes, inter_planes//2, kernel_size=1, stride=1),
86+
BasicConv(inter_planes//2, (inter_planes//4)*3, kernel_size=(1,3), stride=1, padding=(0,1)),
87+
BasicConv((inter_planes//4)*3, inter_planes, kernel_size=(3,1), stride=stride, padding=(1,0)),
88+
BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=5, dilation=5, relu=False)
89+
)
90+
91+
self.ConvLinear = BasicConv(4*inter_planes, out_planes, kernel_size=1, stride=1, relu=False)
92+
self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False)
93+
self.relu = nn.ReLU(inplace=False)
94+
def forward(self,x):
95+
x0 = self.branch0(x)
96+
x1 = self.branch1(x)
97+
x2 = self.branch2(x)
98+
x3 = self.branch3(x)
99+
100+
out = torch.cat((x0,x1,x2,x3),1)
101+
out = self.ConvLinear(out)
102+
short = self.shortcut(x)
103+
out = out*self.scale + short
104+
out = self.relu(out)
105+
106+
return out

0 commit comments

Comments
 (0)