@@ -35,3 +35,72 @@ def vgg(cfg, i, batch_norm=False):
3535 '512' : [64 , 64 , 'M' , 128 , 128 , 'M' , 256 , 256 , 256 , 'C' , 512 , 512 , 512 , 'M' ,
3636 512 , 512 , 512 ],
3737}
38+ class BasicConv (nn .Module ):
39+
40+ def __init__ (self , in_planes , out_planes , kernel_size , stride = 1 , padding = 0 , dilation = 1 , groups = 1 , relu = True , bn = True , bias = False ):
41+ super (BasicConv , self ).__init__ ()
42+ self .out_channels = out_planes
43+ self .conv = nn .Conv2d (in_planes , out_planes , kernel_size = kernel_size , stride = stride , padding = padding , dilation = dilation , groups = groups , bias = bias )
44+ self .bn = nn .BatchNorm2d (out_planes ,eps = 1e-5 , momentum = 0.01 , affine = True ) if bn else None
45+ self .relu = nn .ReLU (inplace = True ) if relu else None
46+
47+ def forward (self , x ):
48+ x = self .conv (x )
49+ if self .bn is not None :
50+ x = self .bn (x )
51+ if self .relu is not None :
52+ x = self .relu (x )
53+ return x
54+ class BasicRFB_a (nn .Module ):
55+
56+ def __init__ (self , in_planes , out_planes , stride = 1 , scale = 0.1 ):
57+ super (BasicRFB_a , self ).__init__ ()
58+ self .scale = scale
59+ self .out_channels = out_planes
60+ inter_planes = in_planes // 4
61+
62+
63+ self .branch0 = nn .Sequential (
64+ BasicConv (in_planes , inter_planes , kernel_size = 1 , stride = 1 ),
65+ BasicConv (inter_planes , inter_planes , kernel_size = 3 , stride = 1 , padding = 1 ,relu = False )
66+ )
67+ self .branch1 = nn .Sequential (
68+ BasicConv (in_planes , inter_planes , kernel_size = 1 , stride = 1 ),
69+ BasicConv (inter_planes , inter_planes , kernel_size = (3 ,1 ), stride = 1 , padding = (1 ,0 )),
70+ BasicConv (inter_planes , inter_planes , kernel_size = 3 , stride = 1 , padding = 3 , dilation = 3 , relu = False )
71+ )
72+ self .branch2 = nn .Sequential (
73+ BasicConv (in_planes , inter_planes , kernel_size = 1 , stride = 1 ),
74+ BasicConv (inter_planes , inter_planes , kernel_size = (1 ,3 ), stride = stride , padding = (0 ,1 )),
75+ BasicConv (inter_planes , inter_planes , kernel_size = 3 , stride = 1 , padding = 3 , dilation = 3 , relu = False )
76+ )
77+ '''
78+ self.branch3 = nn.Sequential(
79+ BasicConv(in_planes, inter_planes, kernel_size=1, stride=1),
80+ BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=1),
81+ BasicConv(inter_planes, inter_planes, kernel_size=3, stride=1, padding=3, dilation=3, relu=False)
82+ )
83+ '''
84+ self .branch3 = nn .Sequential (
85+ BasicConv (in_planes , inter_planes // 2 , kernel_size = 1 , stride = 1 ),
86+ BasicConv (inter_planes // 2 , (inter_planes // 4 )* 3 , kernel_size = (1 ,3 ), stride = 1 , padding = (0 ,1 )),
87+ BasicConv ((inter_planes // 4 )* 3 , inter_planes , kernel_size = (3 ,1 ), stride = stride , padding = (1 ,0 )),
88+ BasicConv (inter_planes , inter_planes , kernel_size = 3 , stride = 1 , padding = 5 , dilation = 5 , relu = False )
89+ )
90+
91+ self .ConvLinear = BasicConv (4 * inter_planes , out_planes , kernel_size = 1 , stride = 1 , relu = False )
92+ self .shortcut = BasicConv (in_planes , out_planes , kernel_size = 1 , stride = stride , relu = False )
93+ self .relu = nn .ReLU (inplace = False )
94+ def forward (self ,x ):
95+ x0 = self .branch0 (x )
96+ x1 = self .branch1 (x )
97+ x2 = self .branch2 (x )
98+ x3 = self .branch3 (x )
99+
100+ out = torch .cat ((x0 ,x1 ,x2 ,x3 ),1 )
101+ out = self .ConvLinear (out )
102+ short = self .shortcut (x )
103+ out = out * self .scale + short
104+ out = self .relu (out )
105+
106+ return out
0 commit comments