 class StringNet(nn.Module):
-    def __init__(self, n_classes, seq_length, batch_size):
-        """
-        In the constructor we instantiate two nn.Linear modules and assign them as
-        member variables.
-        """
-        super(StringNet, self).__init__()
-
-        self.n_classes = n_classes
-        self.seq_length = seq_length
-        self.batch_size = batch_size
-        self.hidden_dim = 200
-
-        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=0)
-        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=0)
-        self.pool1 = nn.MaxPool2d(2)
-
-        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0)
-        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0)
-        self.pool2 = nn.MaxPool2d(2)
-
-        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=0)
-        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=0)
-        self.pool3 = nn.MaxPool2d(2)
-
-        self.fc1 = nn.Linear(128 * 9 * 34, 128 * seq_length)  # depend on the flatten output,
-        # checked if line 49. Dont know if there is an auto solution
-
-        self.lstm = nn.LSTM(128*seq_length, self.hidden_dim, bias=True, bidirectional=True)
-
-        self.fc2 = nn.Linear(self.hidden_dim*2, n_classes)
-
-
-    def init_hidden(self, input_length):
-        # The axes semantics are (num_layers * num_directions, minibatch_size, hidden_dim)
-        return (torch.zeros(2, input_length, self.hidden_dim).to(device),
-                torch.zeros(2, input_length, self.hidden_dim).to(device))
-
-
-    def forward(self, x):
-        """
-        In the forward function we accept a Variable of input data and we must
-        return a Variable of output data. We can use Modules defined in the
-        constructor as well as arbitrary operators on Variables.
-        """
-        current_batch_size = x.shape[0]
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = self.pool1(x)
-
-        x = F.relu(self.conv3(x))
-        x = F.relu(self.conv4(x))
-        x = self.pool2(x)
-
-        x = F.relu(self.conv5(x))
-        x = F.relu(self.conv6(x))
-        x = self.pool3(x)
-
-        x = x.view(x.size(0), -1)  # flatten
-        x = F.relu(self.fc1(x))
-
-        features = x.view(1, current_batch_size, -1).repeat(self.seq_length, 1, 1)
-        hidden = self.init_hidden(current_batch_size)
-
-        outs, hidden = self.lstm(features, hidden)
-
-        # Decode the hidden state of the last time step
-        outs = self.fc2(outs)
-        outs = F.log_softmax(outs, 2)
-
-        return outs
+    def __init__(self, n_classes, seq_length, batch_size):
+        """
+        In the constructor we instantiate two nn.Linear modules and assign them as
+        member variables.
+        """
+        super(StringNet, self).__init__()
+
+        self.n_classes = n_classes
+        self.seq_length = seq_length
+        self.batch_size = batch_size
+        self.hidden_dim = 200
+        self.bidirectional = True
+        self.lstm_layers = 1
+
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=0)
+        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=0)
+        self.pool1 = nn.MaxPool2d(2)
+
+        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=0)
+        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0)
+        self.pool2 = nn.MaxPool2d(2)
+
+        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=0)
+        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=0)
+        self.pool3 = nn.MaxPool2d(2)
+
+        self.fc1 = nn.Linear(128 * 9 * 34, 128 * seq_length)  # depend on the flatten output,
+
+        self.lstm = nn.LSTM(128 * seq_length, self.hidden_dim, num_layers=self.lstm_layers, bias=True,
+                            bidirectional=self.bidirectional)
+
+        self.fc2 = nn.Linear(self.hidden_dim * self.directions, n_classes)
+
+    def init_hidden(self, input_length):
+        # The axes semantics are (num_layers * num_directions, minibatch_size, hidden_dim)
+        return (torch.zeros(self.lstm_layers * self.directions, input_length, self.hidden_dim).to(device),
+                torch.zeros(self.lstm_layers * self.directions, input_length, self.hidden_dim).to(device))
+
+    def forward(self, x):
+        """
+        In the forward function we accept a Variable of input data and we must
+        return a Variable of output data. We can use Modules defined in the
+        constructor as well as arbitrary operators on Variables.
+        """
+        current_batch_size = x.shape[0]
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = self.pool1(x)
+
+        x = F.relu(self.conv3(x))
+        x = F.relu(self.conv4(x))
+        x = self.pool2(x)
+
+        x = F.relu(self.conv5(x))
+        x = F.relu(self.conv6(x))
+        x = self.pool3(x)
+
+        x = x.view(x.size(0), -1)  # flatten
+        x = F.relu(self.fc1(x))
+
+        features = x.view(1, current_batch_size, -1).repeat(self.seq_length, 1, 1)
+        hidden = self.init_hidden(current_batch_size)
+
+        outs, hidden = self.lstm(features, hidden)
+
+        # Decode the hidden state of the last time step
+        outs = self.fc2(outs)
+        outs = F.log_softmax(outs, 2)
+
+        return outs
+
+    @property
+    def directions(self) -> int:
+        return 2 if self.bidirectional else 1
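The removed comment ("Dont know if there is an auto solution") asks whether the hard-coded 128 * 9 * 34 input size of fc1 could be derived automatically. One common approach is sketched below: push a dummy tensor through the conv/pool stack once and read off the flattened size. The 3x100x300 input shape is an assumption (it is the size consistent with a 9x34 feature map after the three conv/pool stages), and the helper name flattened_conv_size plus the nn.Sequential wrapper are illustrative only, not part of this commit.

import torch
import torch.nn as nn

def flattened_conv_size(conv_stack: nn.Module, in_shape=(3, 100, 300)) -> int:
    # Forward a single zero-filled sample and count the features after flattening.
    with torch.no_grad():
        out = conv_stack(torch.zeros(1, *in_shape))
    return out.view(1, -1).size(1)

# Same layer sequence as StringNet's conv stage, wrapped in nn.Sequential
# purely for this demonstration:
conv_stage = nn.Sequential(
    nn.Conv2d(3, 32, 3), nn.ReLU(), nn.Conv2d(32, 32, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.Conv2d(64, 64, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3), nn.ReLU(), nn.Conv2d(128, 128, 3), nn.ReLU(), nn.MaxPool2d(2),
)
print(flattened_conv_size(conv_stage))  # 39168 == 128 * 9 * 34
# In __init__ one could then write:
# self.fc1 = nn.Linear(flattened_conv_size(conv_stage), 128 * seq_length)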
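A minimal smoke test of the new version, assuming it runs in the same module as StringNet (which reads a module-level device) and that torch, torch.nn as nn, and torch.nn.functional as F are already imported at the top of the file. The values n_classes=11, seq_length=7, and batch_size=4 are made-up; 100x300 is the assumed input size that yields the 9x34 feature map expected by fc1.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = StringNet(n_classes=11, seq_length=7, batch_size=4).to(device)
x = torch.randn(4, 3, 100, 300).to(device)  # (batch, channels, height, width)
out = model(x)                              # per-timestep class log-probabilities
print(out.shape)                            # expected: torch.Size([7, 4, 11])

The output has shape (seq_length, batch, n_classes), since the flattened CNN features are repeated for each of the seq_length time steps before the bidirectional LSTM and the final log_softmax over the class dimension.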