@@ -37,7 +37,7 @@ def __init__(self, n_classes, seq_length, batch_size):
         self.seq_length = seq_length
         self.batch_size = batch_size
         self.hidden_dim = 100
-        self.bidirectional = True
+        self.bidirectional = False
         self.lstm_layers = 2
 
         self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=1, stride=1)
@@ -54,12 +54,17 @@ def __init__(self, n_classes, seq_length, batch_size):
 
         self.res_block3 = ResBlock(128, 256)
         self.res_block4 = ResBlock(256, 256)
+        self.res_block5 = ResBlock(256, 256)
 
-        self.res_block5 = ResBlock(256, 512)
-        self.res_block6 = ResBlock(512, 512)
+        self.res_block6 = ResBlock(256, 512)
+        self.res_block7 = ResBlock(512, 512)
+        self.res_block8 = ResBlock(512, 512)
 
-        self.lstm = nn.LSTM(5120, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
-                            bidirectional=self.bidirectional, dropout=0.5)
+        self.lstm_forward = nn.LSTM(1024, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
+                                    dropout=0.5)
+
+        self.lstm_backward = nn.LSTM(1024, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
+                                     dropout=0.5)
 
         self.fc2 = nn.Linear(self.hidden_dim * self.directions, n_classes)
         # self.fc2 = nn.Linear(self.hidden_dim, n_classes)
@@ -89,16 +94,28 @@ def forward(self, x):
 
         x = self.res_block3(x)
         x = self.res_block4(x)
-
         x = self.res_block5(x)
+
         x = self.res_block6(x)
+        x = self.res_block7(x)
+        x = self.res_block8(x)
 
         x = x.view(x.size(0), -1)  # flatten
 
         features = x.view(1, current_batch_size, -1).repeat(self.seq_length, 1, 1)
         hidden = self.init_hidden(current_batch_size)
 
-        outs, hidden = self.lstm(features, hidden)
+        # print(features.shape)
+
+        outs1, _ = self.lstm_forward(features, hidden)
+        outs2, _ = self.lstm_backward(features.flip(0), hidden)
+
+        # print(outs1)
+        # print(outs2)
+        # print(outs1.shape)
+        # assert False
+
+        outs = outs1.add(outs2.flip(0))
 
         # print(outs.shape)
         # assert False
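
For reference, a minimal standalone sketch of the pattern this commit switches to: emulating nn.LSTM(bidirectional=True) with two unidirectional LSTMs, running the second over the time-reversed sequence, flipping its output back, and summing the two directions so the feature size stays at hidden_dim instead of doubling. The tensor sizes below are assumed for illustration only and are not taken from the repository.

import torch
import torch.nn as nn

seq_length, batch_size, feat_dim, hidden_dim = 4, 2, 1024, 100  # assumed sizes

lstm_forward = nn.LSTM(feat_dim, hidden_dim, num_layers=2, bias=True, dropout=0.5)
lstm_backward = nn.LSTM(feat_dim, hidden_dim, num_layers=2, bias=True, dropout=0.5)

features = torch.randn(seq_length, batch_size, feat_dim)  # (seq_len, batch, features)

outs1, _ = lstm_forward(features)            # hidden state defaults to zeros
outs2, _ = lstm_backward(features.flip(0))   # feed the time-reversed sequence
outs = outs1 + outs2.flip(0)                 # re-align time steps, then sum
print(outs.shape)                            # torch.Size([4, 2, 100])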