device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


+class ResBlock(nn.Module):
+    def __init__(self, input_channels, output_channels):
+        super(ResBlock, self).__init__()
+
+        # Main path: a stride-2 3x3 conv (downsamples and changes the channel count) followed by a stride-1 3x3 conv
+        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1, stride=2)
+        self.bn1 = nn.BatchNorm2d(output_channels)
+        self.conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, padding=1, stride=1)
+        self.bn2 = nn.BatchNorm2d(output_channels)
+        # Shortcut path: a stride-2 1x1 projection so the residual matches the main path's shape
+        self.res_conv = nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=1, padding=0, stride=2)
+        self.bn_res = nn.BatchNorm2d(output_channels)
+
+    def forward(self, x):
+        res = x
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = self.bn2(self.conv2(x))
+        # Add the projected shortcut before the final activation
+        x = F.relu(x.add(self.bn_res(self.res_conv(res))))
+
+        return x
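+
+# Illustrative shape check for ResBlock (the input size below is chosen arbitrarily for illustration):
+#   block = ResBlock(64, 128)
+#   block(torch.randn(1, 64, 64, 160)).shape  # torch.Size([1, 128, 32, 80]) -- the stride-2 convs halve H and W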
+
+
class StringNet(nn.Module):
    def __init__(self, n_classes, seq_length, batch_size):
        """
@@ -18,7 +38,7 @@ def __init__(self, n_classes, seq_length, batch_size):
        self.batch_size = batch_size
        self.hidden_dim = 100
        self.bidirectional = True
-        self.lstm_layers = 1
+        self.lstm_layers = 2

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=1, stride=1)
        self.bn1 = nn.BatchNorm2d(64)
@@ -29,29 +49,17 @@ def __init__(self, n_classes, seq_length, batch_size):
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

-        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2)
-        self.bn4 = nn.BatchNorm2d(128)
-        self.conv5 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, stride=1)
-        self.bn5 = nn.BatchNorm2d(128)
-        self.res_conv1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, padding=0, stride=2)
-        self.bn_res1 = nn.BatchNorm2d(128)
-
-        self.conv6 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, stride=2)
-        self.bn6 = nn.BatchNorm2d(256)
-        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, stride=1)
-        self.bn7 = nn.BatchNorm2d(256)
-        self.res_conv2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, padding=0, stride=2)
-        self.bn_res2 = nn.BatchNorm2d(256)
-
-        self.conv8 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1, stride=2)
-        self.bn8 = nn.BatchNorm2d(512)
-        self.conv9 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=1)
-        self.bn9 = nn.BatchNorm2d(512)
-        self.res_conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1, padding=0, stride=2)
-        self.bn_res3 = nn.BatchNorm2d(512)
-
-        self.lstm = nn.LSTM(227328, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
-                            bidirectional=self.bidirectional)
+        # Each ResBlock uses stride-2 convolutions, so every one of these six blocks halves the feature map's height and width
+        self.res_block1 = ResBlock(64, 128)
+        self.res_block2 = ResBlock(128, 128)
+
+        self.res_block3 = ResBlock(128, 256)
+        self.res_block4 = ResBlock(256, 256)
+
+        self.res_block5 = ResBlock(256, 512)
+        self.res_block6 = ResBlock(512, 512)
+
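+        # 5120 is the per-timestep input feature size the LSTM expects; the hard-coded value presumably
+        # corresponds to the flattened CNN output for the training image size, which is not shown in this diff.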
+        self.lstm = nn.LSTM(5120, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
+                            bidirectional=self.bidirectional, dropout=0.5)

        self.fc2 = nn.Linear(self.hidden_dim * self.directions, n_classes)
        # self.fc2 = nn.Linear(self.hidden_dim, n_classes)
@@ -76,17 +84,14 @@ def forward(self, x):
        x = self.bn3(self.conv3(x))
        x = sum1 = F.relu(x.add(res1))

-        x = F.relu(self.bn4(self.conv4(x)))
-        x = self.bn5(self.conv5(x))
-        x = sum2 = F.relu(x.add(self.bn_res1(self.res_conv1(sum1))))
+        x = self.res_block1(x)
+        x = self.res_block2(x)

-        x = F.relu(self.bn6(self.conv6(x)))
-        x = self.bn7(self.conv7(x))
-        x = sum3 = F.relu(x.add(self.bn_res2(self.res_conv2(sum2))))
+        x = self.res_block3(x)
+        x = self.res_block4(x)

-        x = F.relu(self.bn8(self.conv8(x)))
-        x = self.bn9(self.conv9(x))
-        x = F.relu(x.add(self.bn_res3(self.res_conv3(sum3))))
+        x = self.res_block5(x)
+        x = self.res_block6(x)

        x = x.view(x.size(0), -1)  # flatten