Commit 5391efc

Deep paper resnet
1 parent a167506 commit 5391efc

2 files changed: +39 -34 lines changed

config.yml

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 data : "ORAND-CAR-2014/CAR-A/"
 train_val_split: 0.8
 epochs: 1500
-batch_size: 8
+batch_size: 32
 lr : 1.0e-4
 verbose: True
 log: train_log.txt
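
Aside, not part of the commit: config.yml is plain YAML, so the natural way to consume these keys is PyYAML's safe_load. The load_config helper below is a hypothetical sketch, not code from this repository.

import yaml  # PyYAML

def load_config(path="config.yml"):
    # Returns a dict such as {"batch_size": 32, "lr": 1e-4, "epochs": 1500, ...}
    with open(path) as f:
        return yaml.safe_load(f)

cfg = load_config()
print(cfg["batch_size"])  # 32 after this commit, previously 8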

model.py

Lines changed: 38 additions & 33 deletions

@@ -5,6 +5,26 @@
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
+class ResBlock(nn.Module):
+    def __init__(self, input_channels, output_channels):
+        super(ResBlock, self).__init__()
+
+        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1, stride=2)
+        self.bn1 = nn.BatchNorm2d(output_channels)
+        self.conv2 = nn.Conv2d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, padding=1, stride=1)
+        self.bn2 = nn.BatchNorm2d(output_channels)
+        self.res_conv = nn.Conv2d(in_channels=input_channels, out_channels=output_channels, kernel_size=1, padding=0, stride=2)
+        self.bn_res = nn.BatchNorm2d(output_channels)
+
+    def forward(self, x):
+        res = x
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = self.bn2(self.conv2(x))
+        x = F.relu(x.add(self.bn_res(self.res_conv(res))))
+
+        return x
+
+
 class StringNet(nn.Module):
     def __init__(self, n_classes, seq_length, batch_size):
         """
@@ -18,7 +38,7 @@ def __init__(self, n_classes, seq_length, batch_size):
         self.batch_size = batch_size
         self.hidden_dim = 100
         self.bidirectional = True
-        self.lstm_layers = 1
+        self.lstm_layers = 2
 
         self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=1, stride=1)
         self.bn1 = nn.BatchNorm2d(64)

@@ -29,29 +49,17 @@ def __init__(self, n_classes, seq_length, batch_size):
         self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1)
         self.bn3 = nn.BatchNorm2d(64)
 
-        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2)
-        self.bn4 = nn.BatchNorm2d(128)
-        self.conv5 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1, stride=1)
-        self.bn5 = nn.BatchNorm2d(128)
-        self.res_conv1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, padding=0, stride=2)
-        self.bn_res1 = nn.BatchNorm2d(128)
-
-        self.conv6 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, stride=2)
-        self.bn6 = nn.BatchNorm2d(256)
-        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1, stride=1)
-        self.bn7 = nn.BatchNorm2d(256)
-        self.res_conv2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, padding=0, stride=2)
-        self.bn_res2 = nn.BatchNorm2d(256)
-
-        self.conv8 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1, stride=2)
-        self.bn8 = nn.BatchNorm2d(512)
-        self.conv9 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1, stride=1)
-        self.bn9 = nn.BatchNorm2d(512)
-        self.res_conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1, padding=0, stride=2)
-        self.bn_res3 = nn.BatchNorm2d(512)
-
-        self.lstm = nn.LSTM(227328, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
-                            bidirectional=self.bidirectional)
+        self.res_block1 = ResBlock(64, 128)
+        self.res_block2 = ResBlock(128, 128)
+
+        self.res_block3 = ResBlock(128, 256)
+        self.res_block4 = ResBlock(256, 256)
+
+        self.res_block5 = ResBlock(256, 512)
+        self.res_block6 = ResBlock(512, 512)
+
+        self.lstm = nn.LSTM(5120, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
+                            bidirectional=self.bidirectional, dropout=0.5)
 
         self.fc2 = nn.Linear(self.hidden_dim * self.directions, n_classes)
         # self.fc2 = nn.Linear(self.hidden_dim, n_classes)
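
Annotation: two things change around the recurrent head. The hard-coded LSTM input size drops from 227328 to 5120 to match the much smaller feature map left by six stride-2 blocks; 5120 corresponds to 512 channels times 10 remaining spatial positions at the input resolution the pipeline uses. And dropout=0.5 only takes effect together with the lstm_layers = 2 change above, since nn.LSTM applies dropout between stacked layers only; with a single layer it is a no-op and PyTorch warns about it. A sketch of the relevant behavior:

import torch.nn as nn

# Same hyperparameters as in the diff; dropout acts on the output of the
# first recurrent layer before it feeds the second.
lstm = nn.LSTM(5120, 100, num_layers=2, bias=True,
               bidirectional=True, dropout=0.5)

# The identical call with num_layers=1 emits a UserWarning that non-zero
# dropout expects num_layers greater than 1.
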
@@ -76,17 +84,14 @@ def forward(self, x):
         x = self.bn3(self.conv3(x))
         x = sum1 = F.relu(x.add(res1))
 
-        x = F.relu(self.bn4(self.conv4(x)))
-        x = self.bn5(self.conv5(x))
-        x = sum2 = F.relu(x.add(self.bn_res1(self.res_conv1(sum1))))
+        x = self.res_block1(x)
+        x = self.res_block2(x)
 
-        x = F.relu(self.bn6(self.conv6(x)))
-        x = self.bn7(self.conv7(x))
-        x = sum3 = F.relu(x.add(self.bn_res2(self.res_conv2(sum2))))
+        x = self.res_block3(x)
+        x = self.res_block4(x)
 
-        x = F.relu(self.bn8(self.conv8(x)))
-        x = self.bn9(self.conv9(x))
-        x = F.relu(x.add(self.bn_res3(self.res_conv3(sum3))))
+        x = self.res_block5(x)
+        x = self.res_block6(x)
 
         x = x.view(x.size(0), -1)  # flatten
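
Annotation: the refactor preserves the computation pattern. Each removed three-line stage (conv + bn, conv + bn, projected residual add) is exactly what ResBlock.forward computes, so res_block1 stands in for the old conv4/conv5/res_conv1 stage, and so on. The substantive change is depth: each channel width now gets two blocks instead of one, and because ResBlock always strides by 2, the second block of each pair downsamples as well. A quick equivalence check (hypothetical, assumes model.py is on the import path), re-expressing one block in the old inline style with shared weights:

import torch
import torch.nn.functional as F
from model import ResBlock

block = ResBlock(64, 128).eval()  # eval() so BatchNorm uses running stats
x = torch.randn(2, 64, 16, 64)

# The old inline formulation, written with the block's own layers:
y = F.relu(block.bn1(block.conv1(x)))
y = block.bn2(block.conv2(y))
y = F.relu(y + block.bn_res(block.res_conv(x)))

assert torch.allclose(block(x), y)  # same op sequence, same result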
