# model.py
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class EncoderLSTM(nn.Module):
    ''' Encodes navigation instructions, returning hidden state context (for
        attention methods) and a decoder initial state. '''

    def __init__(self, vocab_size, embedding_size, hidden_size, padding_idx,
                 dropout_ratio, bidirectional=False, num_layers=1):
        super(EncoderLSTM, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.drop = nn.Dropout(p=dropout_ratio)
        self.num_directions = 2 if bidirectional else 1
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx)
        self.lstm = nn.LSTM(embedding_size, hidden_size, self.num_layers,
                            batch_first=True, dropout=dropout_ratio,
                            bidirectional=bidirectional)
        self.encoder2decoder = nn.Linear(hidden_size * self.num_directions,
                                         hidden_size * self.num_directions)

    def init_state(self, inputs):
        ''' Initialize to zero cell states and hidden states. '''
        batch_size = inputs.size(0)
        h0 = Variable(torch.zeros(
            self.num_layers * self.num_directions,
            batch_size,
            self.hidden_size
        ), requires_grad=False)
        c0 = Variable(torch.zeros(
            self.num_layers * self.num_directions,
            batch_size,
            self.hidden_size
        ), requires_grad=False)
        # States are created on the CPU and moved to the GPU here, so a CUDA
        # device is required.
        return h0.cuda(), c0.cuda()

    def forward(self, inputs, lengths):
        ''' Expects input vocab indices as (batch, seq_len). Also requires a
            list of lengths for dynamic batching; the batch must be sorted by
            decreasing length, as required by pack_padded_sequence. '''
        embeds = self.embedding(inputs)  # (batch, seq_len, embedding_size)
        embeds = self.drop(embeds)
        h0, c0 = self.init_state(inputs)
        packed_embeds = pack_padded_sequence(embeds, lengths, batch_first=True)
        enc_h, (enc_h_t, enc_c_t) = self.lstm(packed_embeds, (h0, c0))

        if self.num_directions == 2:
            h_t = torch.cat((enc_h_t[-1], enc_h_t[-2]), 1)
            c_t = torch.cat((enc_c_t[-1], enc_c_t[-2]), 1)
        else:
            h_t = enc_h_t[-1]
            c_t = enc_c_t[-1]  # (batch, hidden_size)

        decoder_init = nn.Tanh()(self.encoder2decoder(h_t))

        ctx, lengths = pad_packed_sequence(enc_h, batch_first=True)
        ctx = self.drop(ctx)
        # ctx: (batch, seq_len, hidden_size * num_directions)
        # decoder_init, c_t: (batch, hidden_size * num_directions)
        return ctx, decoder_init, c_t
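
# Encoder usage (an illustrative sketch, not part of the original file;
# instr_ids, seq_lengths and the sizes below are placeholder names and
# assumed values). For a sorted batch of tokenised instructions of shape
# (batch, seq_len), the encoder returns the per-step context, the projected
# decoder initial hidden state, and the final cell state:
#
#   encoder = EncoderLSTM(vocab_size=1000, embedding_size=256, hidden_size=512,
#                         padding_idx=0, dropout_ratio=0.5)
#   ctx, h_t, c_t = encoder(instr_ids, seq_lengths)
#   # ctx: (batch, seq_len, 512), h_t: (batch, 512), c_t: (batch, 512)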


class SoftDotAttention(nn.Module):
    '''Soft Dot Attention.

    Ref: http://www.aclweb.org/anthology/D15-1166
    Adapted from PyTorch OPEN NMT.
    '''

    def __init__(self, dim):
        '''Initialize layer.'''
        super(SoftDotAttention, self).__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.sm = nn.Softmax(dim=1)
        self.linear_out = nn.Linear(dim * 2, dim, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, h, context, mask=None):
        '''Propagate h through the network.

        h: batch x dim
        context: batch x seq_len x dim
        mask: batch x seq_len indices to be masked
        '''
        target = self.linear_in(h).unsqueeze(2)  # batch x dim x 1

        # Get attention
        attn = torch.bmm(context, target).squeeze(2)  # batch x seq_len
        if mask is not None:
            # -Inf masking prior to the softmax
            attn.data.masked_fill_(mask, -float('inf'))
        attn = self.sm(attn)
        attn3 = attn.view(attn.size(0), 1, attn.size(1))  # batch x 1 x seq_len

        weighted_context = torch.bmm(attn3, context).squeeze(1)  # batch x dim
        h_tilde = torch.cat((weighted_context, h), 1)

        h_tilde = self.tanh(self.linear_out(h_tilde))
        return h_tilde, attn
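
# Illustrative shape check for SoftDotAttention (a sketch with assumed sizes,
# not from the original file): with h of shape (batch, dim) and context of
# shape (batch, seq_len, dim), forward() returns h_tilde of shape (batch, dim)
# and attention weights of shape (batch, seq_len), e.g.
#
#   attn = SoftDotAttention(dim=512)
#   h_tilde, alpha = attn(torch.zeros(4, 512), torch.zeros(4, 10, 512))
#   # h_tilde: (4, 512), alpha: (4, 10)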


class AttnDecoderLSTM(nn.Module):
    ''' An unrolled LSTM with attention over instructions for decoding
        navigation actions. '''

    def __init__(self, input_action_size, output_action_size, embedding_size,
                 hidden_size, dropout_ratio, feature_size=2048):
        super(AttnDecoderLSTM, self).__init__()
        self.embedding_size = embedding_size
        self.feature_size = feature_size
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_action_size, embedding_size)
        self.drop = nn.Dropout(p=dropout_ratio)
        self.lstm = nn.LSTMCell(embedding_size + feature_size, hidden_size)
        self.attention_layer = SoftDotAttention(hidden_size)
        self.decoder2action = nn.Linear(hidden_size, output_action_size)

    def forward(self, action, feature, h_0, c_0, ctx, ctx_mask=None):
        ''' Takes a single step in the decoder LSTM (allowing sampling).

        action: batch x 1
        feature: batch x feature_size
        h_0: batch x hidden_size
        c_0: batch x hidden_size
        ctx: batch x seq_len x dim
        ctx_mask: batch x seq_len - indices to be masked
        '''
        action_embeds = self.embedding(action)  # (batch, 1, embedding_size)
        # Squeeze only the singleton step dimension, so a batch of size 1 is
        # not collapsed as well.
        action_embeds = action_embeds.squeeze(dim=1)
        concat_input = torch.cat((action_embeds, feature), 1)  # (batch, embedding_size + feature_size)
        drop = self.drop(concat_input)
        h_1, c_1 = self.lstm(drop, (h_0, c_0))
        h_1_drop = self.drop(h_1)
        h_tilde, alpha = self.attention_layer(h_1_drop, ctx, ctx_mask)
        logit = self.decoder2action(h_tilde)
        return h_1, c_1, alpha, logit
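

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original training code;
    # the hyperparameters below are illustrative, not the project's real
    # configuration). EncoderLSTM.init_state() moves the LSTM states to the
    # GPU, so this only runs when CUDA is available.
    if torch.cuda.is_available():
        vocab_size, batch, seq_len = 100, 4, 8
        encoder = EncoderLSTM(vocab_size, embedding_size=32, hidden_size=64,
                              padding_idx=0, dropout_ratio=0.5).cuda()
        decoder = AttnDecoderLSTM(input_action_size=6, output_action_size=6,
                                  embedding_size=16, hidden_size=64,
                                  dropout_ratio=0.5).cuda()
        instr = torch.randint(1, vocab_size, (batch, seq_len)).cuda()
        lengths = [seq_len] * batch  # already sorted (all equal)
        ctx, h_t, c_t = encoder(instr, lengths)
        # One decoding step from the encoder's initial state, with a dummy
        # start action and a zero visual feature.
        action = torch.zeros(batch, 1, dtype=torch.long).cuda()
        feature = torch.zeros(batch, decoder.feature_size).cuda()
        h_1, c_1, alpha, logit = decoder(action, feature, h_t, c_t, ctx)
        print(ctx.shape, alpha.shape, logit.shape)
        # expected: (4, 8, 64), (4, 8), (4, 6)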