model.py
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as packer, pad_packed_sequence as padder


# ----------------------------------------------------------------------------------------------------------------------
class DeepGRU(nn.Module):
    def __init__(self, num_features, num_classes):
        super(DeepGRU, self).__init__()

        self.num_features = num_features
        self.num_classes = num_classes

        # Encoder: three stacked GRU blocks that progressively reduce the hidden size
        self.gru1 = nn.GRU(self.num_features, 512, 2, batch_first=True)
        self.gru2 = nn.GRU(512, 256, 2, batch_first=True)
        self.gru3 = nn.GRU(256, 128, 1, batch_first=True)

        # Attention
        self.attention = Attention(128)

        # Classifier (input is the 256-dim attention output: context + auxiliary context)
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x_padded, x_lengths):
        x_packed = packer(x_padded, x_lengths, batch_first=True)

        # Encode
        output, _ = self.gru1(x_packed)
        output, _ = self.gru2(output)
        output, hidden = self.gru3(output)

        # Pass to attention with the original padding
        output_padded, _ = padder(output, batch_first=True)
        attn_output = self.attention(output_padded, hidden[-1:])

        # Classify
        return self.classifier(attn_output)

    def get_num_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# ----------------------------------------------------------------------------------------------------------------------
class Attention(nn.Module):
    def __init__(self, attention_dim):
        super(Attention, self).__init__()
        self.w = nn.Linear(attention_dim, attention_dim, bias=False)
        self.gru = nn.GRU(128, 128, 1, batch_first=True)

    def forward(self, input_padded, hidden):
        # Attention scores of each time step against the final encoder hidden state,
        # then the context vector as the attention-weighted sum of the encoder outputs
        e = torch.bmm(self.w(input_padded), hidden.permute(1, 2, 0))
        context = torch.bmm(input_padded.permute(0, 2, 1), e.softmax(dim=1))
        context = context.permute(0, 2, 1)

        # Compute the auxiliary context, and concat
        aux_context, _ = self.gru(context, hidden)
        output = torch.cat([aux_context, context], 2).squeeze(1)

        return output
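

# ----------------------------------------------------------------------------------------------------------------------
# Minimal usage sketch: builds a DeepGRU and runs one forward pass on a zero-padded batch of
# variable-length sequences. The feature size (63), class count (10), batch size and sequence
# lengths are illustrative assumptions, not values prescribed by this module.
if __name__ == '__main__':
    num_features, num_classes = 63, 10          # hypothetical values for illustration
    model = DeepGRU(num_features, num_classes)

    # Two example sequences of lengths 50 and 30, zero-padded to the longest one. Lengths are
    # given in descending order because pack_padded_sequence defaults to enforce_sorted=True.
    lengths = torch.tensor([50, 30])
    batch = torch.zeros(2, 50, num_features)
    batch[0] = torch.randn(50, num_features)
    batch[1, :30] = torch.randn(30, num_features)

    model.eval()                                # BatchNorm/Dropout in inference mode
    with torch.no_grad():
        logits = model(batch, lengths)

    print(logits.shape)                         # torch.Size([2, 10])
    print(model.get_num_params())               # number of trainable parameters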