Skip to content

Commit ec79894

Browse files
committed
Job #28
1 parent 5391efc commit ec79894

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

model.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, n_classes, seq_length, batch_size):
3737
self.seq_length = seq_length
3838
self.batch_size = batch_size
3939
self.hidden_dim = 100
40-
self.bidirectional = True
40+
self.bidirectional = False
4141
self.lstm_layers = 2
4242

4343
self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=1, stride=1)
@@ -54,12 +54,17 @@ def __init__(self, n_classes, seq_length, batch_size):
5454

5555
self.res_block3 = ResBlock(128, 256)
5656
self.res_block4 = ResBlock(256, 256)
57+
self.res_block5 = ResBlock(256, 256)
5758

58-
self.res_block5 = ResBlock(256, 512)
59-
self.res_block6 = ResBlock(512, 512)
59+
self.res_block6 = ResBlock(256, 512)
60+
self.res_block7 = ResBlock(512, 512)
61+
self.res_block8 = ResBlock(512, 512)
6062

61-
self.lstm = nn.LSTM(5120, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
62-
bidirectional=self.bidirectional, dropout=0.5)
63+
self.lstm_forward = nn.LSTM(1024, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
64+
dropout=0.5)
65+
66+
self.lstm_backward = nn.LSTM(1024, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
67+
dropout=0.5)
6368

6469
self.fc2 = nn.Linear(self.hidden_dim * self.directions, n_classes)
6570
# self.fc2 = nn.Linear(self.hidden_dim, n_classes)
@@ -89,16 +94,28 @@ def forward(self, x):
8994

9095
x = self.res_block3(x)
9196
x = self.res_block4(x)
92-
9397
x = self.res_block5(x)
98+
9499
x = self.res_block6(x)
100+
x = self.res_block7(x)
101+
x = self.res_block8(x)
95102

96103
x = x.view(x.size(0), -1) # flatten
97104

98105
features = x.view(1, current_batch_size, -1).repeat(self.seq_length, 1, 1)
99106
hidden = self.init_hidden(current_batch_size)
100107

101-
outs, hidden = self.lstm(features, hidden)
108+
# print(features.shape)
109+
110+
outs1, _ = self.lstm_forward(features, hidden)
111+
outs2, _ = self.lstm_backward(features.flip(0), hidden)
112+
113+
# print(outs1)
114+
# print(outs2)
115+
# print(outs1.shape)
116+
# assert False
117+
118+
outs = outs1.add(outs2.flip(0))
102119

103120
# print(outs.shape)
104121
# assert False

0 commit comments

Comments
 (0)