Commit cb7ead2

Generalize model function for easier hyperparameter adaptation.
1 parent 60541a3 commit cb7ead2

File tree

1 file changed: +74 -70 lines changed

model.py

Lines changed: 74 additions & 70 deletions
@@ -6,73 +6,77 @@


 class StringNet(nn.Module):
-    def __init__(self, n_classes, seq_length, batch_size):
-        """
-        In the constructor we instantiate two nn.Linear modules and assign them as
-        member variables.
-        """
-        super(StringNet, self).__init__()
-
-        self.n_classes = n_classes
-        self.seq_length = seq_length
-        self.batch_size = batch_size
-        self.hidden_dim = 200
-
-        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=0)
-        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=0)
-        self.pool1 = nn.MaxPool2d(2)
-
-        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0)
-        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0)
-        self.pool2 = nn.MaxPool2d(2)
-
-        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=0)
-        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=0)
-        self.pool3 = nn.MaxPool2d(2)
-
-        self.fc1 = nn.Linear(128 * 9 * 34, 128 * seq_length)  # depend on the flatten output,
-        # checked in forward(). Dont know if there is an auto solution
-
-        self.lstm = nn.LSTM(128*seq_length, self.hidden_dim, bias=True, bidirectional=True)
-
-        self.fc2 = nn.Linear(self.hidden_dim*2, n_classes)
-
-
-    def init_hidden(self, input_length):
-        # The axes semantics are (num_layers * num_directions, minibatch_size, hidden_dim)
-        return (torch.zeros(2, input_length, self.hidden_dim).to(device),
-                torch.zeros(2, input_length, self.hidden_dim).to(device))
-
-
-    def forward(self, x):
-        """
-        In the forward function we accept a Variable of input data and we must
-        return a Variable of output data. We can use Modules defined in the
-        constructor as well as arbitrary operators on Variables.
-        """
-        current_batch_size = x.shape[0]
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = self.pool1(x)
-
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        x = self.pool2(x)
-
-        x = F.relu(self.conv5(x))
-        x = F.relu(self.conv6(x))
-        x = self.pool3(x)
-
-        x = x.view(x.size(0), -1)  # flatten
-        x = F.relu(self.fc1(x))
-
-        features = x.view(1, current_batch_size, -1).repeat(self.seq_length, 1, 1)
-        hidden = self.init_hidden(current_batch_size)
-
-        outs, hidden = self.lstm(features, hidden)
-
-        # Decode the hidden state of the last time step
-        outs = self.fc2(outs)
-        outs = F.log_softmax(outs, 2)
-
-        return outs
+    def __init__(self, n_classes, seq_length, batch_size):
+        """
+        In the constructor we instantiate two nn.Linear modules and assign them as
+        member variables.
+        """
+        super(StringNet, self).__init__()
+
+        self.n_classes = n_classes
+        self.seq_length = seq_length
+        self.batch_size = batch_size
+        self.hidden_dim = 200
+        self.bidirectional = True
+        self.lstm_layers = 1
+
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=0)
+        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=0)
+        self.pool1 = nn.MaxPool2d(2)
+
+        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0)
+        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0)
+        self.pool2 = nn.MaxPool2d(2)
+
+        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=0)
+        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=0)
+        self.pool3 = nn.MaxPool2d(2)
+
+        self.fc1 = nn.Linear(128 * 9 * 34, 128 * seq_length)  # depend on the flatten output,
+
+        self.lstm = nn.LSTM(128 * seq_length, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
+                            bidirectional=self.bidirectional)
+
+        self.fc2 = nn.Linear(self.hidden_dim * self.directions, n_classes)
+
+    def init_hidden(self, input_length):
+        # The axes semantics are (num_layers * num_directions, minibatch_size, hidden_dim)
+        return (torch.zeros(self.lstm_layers * self.directions, input_length, self.hidden_dim).to(device),
+                torch.zeros(self.lstm_layers * self.directions, input_length, self.hidden_dim).to(device))
+
+    def forward(self, x):
+        """
+        In the forward function we accept a Variable of input data and we must
+        return a Variable of output data. We can use Modules defined in the
+        constructor as well as arbitrary operators on Variables.
+        """
+        current_batch_size = x.shape[0]
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = self.pool1(x)
+
+        x = F.relu(self.conv3(x))
+        x = F.relu(self.conv4(x))
+        x = self.pool2(x)
+
+        x = F.relu(self.conv5(x))
+        x = F.relu(self.conv6(x))
+        x = self.pool3(x)
+
+        x = x.view(x.size(0), -1)  # flatten
+        x = F.relu(self.fc1(x))
+
+        features = x.view(1, current_batch_size, -1).repeat(self.seq_length, 1, 1)
+        hidden = self.init_hidden(current_batch_size)
+
+        outs, hidden = self.lstm(features, hidden)
+
+        # Decode the hidden state of the last time step
+        outs = self.fc2(outs)
+        outs = F.log_softmax(outs, 2)
+
+        return outs
+
+    @property
+    def directions(self) -> int:
+        return 2 if self.bidirectional else 1
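
For reference, a minimal smoke test of the new version could look like the sketch below. It is not part of the commit: the concrete values (n_classes=11, seq_length=5, the 100 x 300 input) are hypothetical, it assumes the StringNet class from the diff is in scope, and it stands in for the imports and the device global that the file presumably defines above this hunk.

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cpu")  # assumption: the real file defines `device` above the hunk

model = StringNet(n_classes=11, seq_length=5, batch_size=4).to(device)

# The hard-coded 128 * 9 * 34 in fc1 pins the input resolution. A 3 x 100 x 300
# image satisfies it: 100 -> 96 -> 48 -> 44 -> 22 -> 18 -> 9 (height) and
# 300 -> 296 -> 148 -> 144 -> 72 -> 68 -> 34 (width) across the three conv/pool blocks.
x = torch.randn(4, 3, 100, 300, device=device)
outs = model(x)
print(outs.shape)  # torch.Size([5, 4, 11]) -- (seq_length, batch, n_classes)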

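The new directions property is what keeps the three coupled places consistent: the LSTM constructor, the initial hidden state in init_hidden, and the fc2 input width. A standalone sketch of that bookkeeping with a plain nn.LSTM (all sizes hypothetical, not taken from the repo):

import torch
import torch.nn as nn

hidden_dim, lstm_layers, bidirectional = 200, 1, True
directions = 2 if bidirectional else 1

lstm = nn.LSTM(64, hidden_dim, num_layers=lstm_layers, bidirectional=bidirectional)

# h0/c0 must have shape (num_layers * num_directions, batch, hidden_dim),
# exactly what init_hidden now computes from the two hyperparameters.
h0 = torch.zeros(lstm_layers * directions, 4, hidden_dim)
c0 = torch.zeros(lstm_layers * directions, 4, hidden_dim)

out, _ = lstm(torch.randn(10, 4, 64), (h0, c0))
print(out.shape)  # torch.Size([10, 4, 400]) -- hidden_dim * directions feeds fc2

Flipping bidirectional to False or raising lstm_layers in __init__ now propagates to all three places automatically, where the old code hard-coded the factor 2.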