#!/usr/bin/python
# Author: GMFTBY
# Time: 2020.2.10
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import random
import numpy as np
import ipdb
from .layers import *
'''
KgCVAE is very similar to VHRED. The biggest difference between them is the
BOW (bag-of-words) loss; the BOW loss is also compatible with KL annealing.
Compared with VHRED, the architecture is slightly modified; more details can be
found in the paper: Learning Discourse-level Diversity for Neural Dialog
Models using Conditional Variational Autoencoders.
Code adapted from: https://github.com/lipiji/dialogue-hred-vhred
'''
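

# How the three training signals returned below are typically combined
# (a minimal sketch, not part of this file's API; the annealing schedule for
# kl_weight is an assumption, while the loss terms themselves are the values
# returned by KgCVAE.forward):
def kgcvae_objective(nll_loss, kl_div, bow_loss, kl_weight=1.0):
    # kl_weight is usually ramped from 0 to 1 during training (KL annealing)
    return nll_loss + kl_weight * kl_div + bow_loss
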
class Utterance_encoder(nn.Module):
'''
Bidirectional GRU
'''
def __init__(self, input_size, embedding_size,
hidden_size, dropout=0.5, n_layer=1, pretrained=None):
super(Utterance_encoder, self).__init__()
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.input_size = input_size
self.n_layer = n_layer
# self.embed = nn.Embedding(input_size, self.embedding_size)
self.gru = nn.GRU(self.embedding_size, self.hidden_size, num_layers=n_layer,
dropout=(0 if n_layer == 1 else dropout), bidirectional=True)
# hidden_project
# self.hidden_proj = nn.Linear(n_layer * 2 * self.hidden_size, hidden_size)
# self.bn = nn.BatchNorm1d(num_features=hidden_size)
self.init_weight()
def init_weight(self):
# init.xavier_normal_(self.hidden_proj.weight)
init.xavier_normal_(self.gru.weight_hh_l0)
init.xavier_normal_(self.gru.weight_ih_l0)
self.gru.bias_ih_l0.data.fill_(0.0)
self.gru.bias_hh_l0.data.fill_(0.0)
def forward(self, inpt, lengths, hidden=None):
# use pack_padded
# inpt: [seq_len, batch], lengths: [batch_size]
# embedded = self.embed(inpt) # [seq_len, batch, input_size]
        if hidden is None:
hidden = torch.randn(self.n_layer * 2, len(lengths),
self.hidden_size)
if torch.cuda.is_available():
hidden = hidden.cuda()
        if torch.is_tensor(lengths):
            # pack_padded_sequence requires the lengths on the CPU
            lengths = lengths.cpu()
        embedded = nn.utils.rnn.pack_padded_sequence(inpt, lengths,
                                                     enforce_sorted=False)
_, hidden = self.gru(embedded, hidden)
# [n_layer * bidirection, batch, hidden_size]
# hidden = hidden.reshape(hidden.shape[1], -1)
# ipdb.set_trace()
        hidden = hidden.sum(axis=0)    # [n_layer * 2, batch, hidden] -> [batch, hidden]
# hidden = hidden.permute(1, 0, 2) # [batch, n_layer * bidirectional, hidden_size]
# hidden = hidden.reshape(hidden.size(0), -1) # [batch, *]
# hidden = self.bn(hidden)
# hidden = self.hidden_proj(hidden)
hidden = torch.tanh(hidden) # [batch, hidden]
return hidden
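
# Usage sketch for Utterance_encoder (shapes follow the comments above; the
# hyper-parameters are illustrative assumptions):
#   enc = Utterance_encoder(input_size=1000, embedding_size=300, hidden_size=512)
#   inpt = torch.randn(20, 32, 300)         # [seq_len, batch, embedding_size]
#   lengths = torch.randint(1, 21, (32,))   # [batch]
#   hidden = enc(inpt, lengths)             # [batch, 512]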
class Context_encoder(nn.Module):
    '''
    input_size is the utterance encoder hidden size (the two directions of the
    bidirectional utterance encoder are summed, not concatenated)
    '''
def __init__(self, input_size, hidden_size, dropout=0.5):
super(Context_encoder, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.gru = nn.GRU(self.input_size, self.hidden_size, bidirectional=True)
# self.drop = nn.Dropout(p=dropout)
self.init_weight()
def init_weight(self):
init.xavier_normal_(self.gru.weight_hh_l0)
init.xavier_normal_(self.gru.weight_ih_l0)
self.gru.bias_ih_l0.data.fill_(0.0)
self.gru.bias_hh_l0.data.fill_(0.0)
def forward(self, inpt, hidden=None):
# inpt: [turn_len, batch, input_size]
# hidden
        if hidden is None:
hidden = torch.randn(2, inpt.shape[1], self.hidden_size)
if torch.cuda.is_available():
hidden = hidden.cuda()
# inpt = self.drop(inpt)
output, hidden = self.gru(inpt, hidden)
# [seq, batch, hidden]
output = output[:, :, :self.hidden_size] + output[:, :, self.hidden_size:]
# hidden: [2, batch, hidden_size]
# hidden = hidden.squeeze(0)
hidden = torch.tanh(hidden)
return output, hidden
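
# Context_encoder sums the forward and backward halves of its bidirectional
# outputs, so `output` keeps the shape [turn_len, batch, hidden_size] that the
# decoder's attention expects.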
class VariableLayer(nn.Module):
def __init__(self, context_hidden, encoder_hidden, z_hidden):
super(VariableLayer, self).__init__()
self.context_hidden = context_hidden
self.encoder_hidden = encoder_hidden
self.z_hidden = z_hidden
self.prior_h = nn.ModuleList([nn.Linear(context_hidden, context_hidden),
nn.Linear(context_hidden, context_hidden)])
self.prior_mu = nn.Linear(context_hidden, z_hidden)
self.prior_var = nn.Linear(context_hidden, z_hidden)
self.posterior_h = nn.ModuleList([nn.Linear(context_hidden+encoder_hidden,
context_hidden),
nn.Linear(context_hidden,
context_hidden)])
self.posterior_mu = nn.Linear(context_hidden, z_hidden)
self.posterior_var = nn.Linear(context_hidden, z_hidden)
self.softplus = nn.Softplus()
def prior(self, context_outputs):
# context_outputs: [batch, context_hidden]
h_prior = context_outputs
for linear in self.prior_h:
h_prior = torch.tanh(linear(h_prior))
mu_prior = self.prior_mu(h_prior)
var_prior = self.softplus(self.prior_var(h_prior))
return mu_prior, var_prior
def posterior(self, context_outputs, encoder_hidden):
# context_outputs: [batch, context_hidden]
# encoder_hidden: [batch, encoder_hidden]
h_posterior = torch.cat([context_outputs, encoder_hidden], 1)
for linear in self.posterior_h:
h_posterior = torch.tanh(linear(h_posterior))
mu_posterior = self.posterior_mu(h_posterior)
var_posterior = self.softplus(self.posterior_var(h_posterior))
return mu_posterior, var_posterior
def kl_div(self, mu1, var1, mu2, var2):
one = torch.FloatTensor([1.0])
if torch.cuda.is_available():
one = one.cuda()
kl_div = torch.sum(0.5 * (torch.log(var2) - torch.log(var1)
+ (var1 + (mu1 - mu2).pow(2)) / var2 - one), 1)
return kl_div
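
    # Closed form used above, for diagonal Gaussians:
    # KL(N(mu1, var1) || N(mu2, var2))
    #     = 0.5 * sum(log(var2 / var1) + (var1 + (mu1 - mu2)^2) / var2 - 1)
    # summed over the z_hidden dimensions for each batch element.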
def forward(self, context_outputs, encoder_hidden=None, train=True):
# context_outputs: [batch, context_hidden]
# Return: z_sent [batch, z_hidden]
# Return: kl_div, scalar for calculating the loss
mu_prior, var_prior = self.prior(context_outputs)
eps = torch.randn((context_outputs.shape[0], self.z_hidden))
if torch.cuda.is_available():
eps = eps.cuda()
if train:
mu_posterior, var_posterior = self.posterior(context_outputs,
encoder_hidden)
z_sent = mu_posterior + torch.sqrt(var_posterior) * eps
kl_div = self.kl_div(mu_posterior, var_posterior,
mu_prior, var_prior)
kl_div = torch.sum(kl_div)
else:
z_sent = mu_prior + torch.sqrt(var_prior) * eps
kl_div = None
return z_sent, kl_div
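
# VariableLayer samples z with the reparameterization trick,
# z = mu + sqrt(var) * eps, so gradients flow through mu and var; at train
# time z comes from the posterior q(z|c, x), at test time from the prior p(z|c).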
class Decoder(nn.Module):
    '''
    Maximum likelihood decoding of the utterance.
    output_size is the size of the output vocabulary.
    The attention module requires the decoder hidden size to match the
    context encoder hidden size.
    '''
def __init__(self, output_size, embed_size, hidden_size, n_layer=2, dropout=0.5, pretrained=None):
super(Decoder, self).__init__()
self.output_size = output_size
self.hidden_size = hidden_size
self.embed_size = embed_size
# self.embed = nn.Embedding(self.output_size, self.embed_size)
self.gru = nn.GRU(self.embed_size + self.hidden_size, self.hidden_size,
num_layers=n_layer,
dropout=(0 if n_layer == 1 else dropout))
self.out = nn.Linear(hidden_size, output_size)
# attention on context encoder
self.attn = Attention(hidden_size)
self.init_weight()
def init_weight(self):
init.xavier_normal_(self.gru.weight_hh_l0)
init.xavier_normal_(self.gru.weight_ih_l0)
self.gru.bias_ih_l0.data.fill_(0.0)
self.gru.bias_hh_l0.data.fill_(0.0)
def forward(self, inpt, last_hidden, encoder_outputs):
# inpt: [batch_size], last_hidden: [2, batch, hidden_size]
# encoder_outputs: [turn_len, batch, hidden_size]
embedded = inpt.unsqueeze(0) # [1, batch_size, embed_size]
key = last_hidden.sum(axis=0) # [batch, hidden_size]
# [batch, 1, seq_len]
attn_weights = self.attn(key, encoder_outputs)
context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
context = context.transpose(0, 1) # [1, batch, hidden]
rnn_input = torch.cat([embedded, context], 2) # [1, batch, embed+hidden]
        # output: [1, batch, hidden_size], hidden: [n_layer, batch, hidden_size]
output, hidden = self.gru(rnn_input, last_hidden)
output = output.squeeze(0) # [batch, hidden_size]
# context = context.squeeze(0) # [batch, hidden]
# output = torch.cat([output, context], 1) # [batch, 2 * hidden]
output = self.out(output) # [batch, output_size]
output = F.log_softmax(output, dim=1)
return output, hidden
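
# The decoder returns log-probabilities (log_softmax), so the reconstruction
# loss is presumably nn.NLLLoss over these outputs; the training script is not
# part of this file, so that pairing is an assumption.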
class KgCVAE(nn.Module):
    '''
    The source and target vocabularies are the same
    '''
def __init__(self, embed_size, input_size, output_size,
utter_hidden, context_hidden, decoder_hidden,
teach_force=0.5, eos=24743, pad=24745, sos=24742,
unk=24745, dropout=0.5, utter_n_layer=1, z_hidden=100,
pretrained=None):
super(KgCVAE, self).__init__()
self.teach_force = teach_force
assert input_size == output_size, 'The src and tgt vocab size must be the same'
self.vocab_size = input_size
self.eos, self.pad, self.sos, self.unk = eos, pad, sos, unk
self.utter_encoder = Utterance_encoder(input_size,
embed_size,
utter_hidden,
dropout=dropout,
n_layer=utter_n_layer,
pretrained=pretrained)
self.context_encoder = Context_encoder(utter_hidden, context_hidden,
dropout=dropout)
self.decoder = Decoder(output_size, embed_size, decoder_hidden,
dropout=dropout, n_layer=utter_n_layer,
pretrained=pretrained)
self.variablelayer = VariableLayer(context_hidden,
utter_hidden, z_hidden)
self.context2decoder = nn.Linear(context_hidden+z_hidden,
context_hidden)
self.embedding = nn.Embedding(self.vocab_size, embed_size)
# BOW loss computation
self.mlp_h = nn.Linear(z_hidden, decoder_hidden)
self.mlp_p = nn.Linear(decoder_hidden, self.vocab_size)
def compute_bow_loss(self, z, tgt):
# tgt: [lengths, batch]; z: [batch, z_hidden]
target_bow = [to_bow(i, self.vocab_size, self.pad, self.sos, self.eos, self.unk) for i in tgt.transpose(0, 1)] # [batch, vocab]
target_bow = torch.stack(target_bow, dim=0) # [batch, vocab]
if torch.cuda.is_available():
target_bow = target_bow.cuda()
bow_logits = self.mlp_p(self.mlp_h(z))
bow_loss = bag_of_words_loss(bow_logits, target_bow)
return bow_loss
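
    # The BOW loss asks z (through the two-layer MLP above) to predict the
    # multiset of words in the response, which keeps z informative and guards
    # against KL collapse. to_bow and bag_of_words_loss come from .layers; the
    # exact handling of the pad/sos/eos/unk tokens there is assumed from the
    # arguments passed above.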
def forward(self, src, tgt, lengths):
# src: [turns, lengths, batch], tgt: [lengths, batch]
# lengths: [turns, batch]
turn_size, batch_size, maxlen = len(src), tgt.size(1), tgt.size(0)
outputs = torch.zeros(maxlen, batch_size, self.vocab_size)
if torch.cuda.is_available():
outputs = outputs.cuda()
# utterance encoding
turns = []
for i in range(turn_size):
# sbatch = src[i].transpose(0, 1) # [seq_len, batch]
# [4, batch, hidden]
inpt_ = self.embedding(src[i])
hidden = self.utter_encoder(inpt_, lengths[i]) # utter_hidden
turns.append(hidden)
turns = torch.stack(turns) # [turn_len, batch, utter_hidden]
        # encode the tgt for the posterior (recognition) network
tgt_lengths = []
for i in range(batch_size):
seq = tgt[:, i]
counter = 0
for j in seq:
if j.item() == self.pad:
break
counter += 1
tgt_lengths.append(counter)
tgt_lengths = torch.tensor(tgt_lengths, dtype=torch.long)
if torch.cuda.is_available():
tgt_lengths = tgt_lengths.cuda()
# [batch, utter_hidden]
tgt_ = self.embedding(tgt)
        with torch.no_grad():    # NOTE: do not backprop through the target utterance encoding
tgt_encoder_hidden = self.utter_encoder(tgt_, tgt_lengths)
# context encoding
# output: [seq, batch, hidden], [2, batch, hidden]
context_output, hidden = self.context_encoder(turns)
# hidden + variable z
# z_sent: [batch, z_hidden]
z_sent, kl_div = self.variablelayer(hidden.sum(axis=0),
encoder_hidden=tgt_encoder_hidden,
train=True)
# bow loss
bow_loss = self.compute_bow_loss(z_sent, tgt)
z_sent = z_sent.repeat(2, 1, 1) # [2, batch, z_hidden]
hidden = torch.cat([hidden, z_sent], dim=2) # [2, batch, z_hidden+hidden]
hidden = torch.tanh(self.context2decoder(hidden))
# decoding
# tgt = tgt.transpose(0, 1) # [seq_len, batch]
# hidden = hidden.unsqueeze(0) # [1, batch, hidden_size]
output = tgt[0, :] # [batch]
use_teacher = random.random() < self.teach_force
if use_teacher:
for t in range(1, maxlen):
output = self.embedding(output)
output, hidden = self.decoder(output, hidden, context_output)
outputs[t] = output
output = tgt[t]
else:
for t in range(1, maxlen):
output = self.embedding(output)
output, hidden = self.decoder(output, hidden, context_output)
outputs[t] = output
# output = torch.max(output, 1)[1]
                output = output.topk(1)[1].squeeze(1).detach()    # greedy token ids, [batch]
return outputs, kl_div, bow_loss # [maxlen, batch, vocab_size]
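
    # Note: outputs[0] stays all zeros because decoding starts at t=1 with
    # tgt[0] (the <sos> token) as the first decoder input, so the training
    # loss should be computed against tgt[1:].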
def predict(self, src, maxlen, lengths, loss=False):
# predict for test dataset, return outputs: [maxlen, batch_size]
# src: [turn, max_len, batch_size], lengths: [turn, batch_size]
with torch.no_grad():
turn_size, batch_size = len(src), src[0].size(1)
outputs = torch.zeros(maxlen, batch_size)
floss = torch.zeros(maxlen, batch_size, self.vocab_size)
if torch.cuda.is_available():
outputs = outputs.cuda()
floss = floss.cuda()
turns = []
for i in range(turn_size):
# sbatch = src[i].transpose(0, 1)
inpt_ = self.embedding(src[i])
hidden = self.utter_encoder(inpt_, lengths[i])
turns.append(hidden)
turns = torch.stack(turns)
context_output, hidden = self.context_encoder(turns)
# hidden = hidden.unsqueeze(0)
# hidden + variable z
# z_sent: [batch, z_hidden]
z_sent, kl_div = self.variablelayer(hidden.sum(axis=0),
encoder_hidden=None,
train=False)
z_sent = z_sent.repeat(2, 1, 1) # [2, batch, z_hidden]
hidden = torch.cat([hidden, z_sent], dim=2)
hidden = torch.tanh(self.context2decoder(hidden))
output = torch.zeros(batch_size, dtype=torch.long).fill_(self.sos)
if torch.cuda.is_available():
output = output.cuda()
for i in range(1, maxlen):
output = self.embedding(output)
output, hidden = self.decoder(output, hidden, context_output)
floss[i] = output
output = output.max(1)[1]
outputs[i] = output
if loss:
return outputs, floss
else:
return outputs
if __name__ == "__main__":
pass
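    # A minimal shape-level smoke test (a sketch only: the relative import of
    # .layers means this file must be run as a package module, and the
    # hyper-parameters below are illustrative assumptions). The decoder reuses
    # the bidirectional context hidden state [2, batch, hidden] as its initial
    # state, so utter_n_layer=2 is required for the shapes to line up:
    #   V, turns, seq_len, batch = 1000, 3, 20, 8
    #   model = KgCVAE(embed_size=300, input_size=V, output_size=V,
    #                  utter_hidden=512, context_hidden=512, decoder_hidden=512,
    #                  utter_n_layer=2)
    #   src = torch.randint(0, V, (turns, seq_len, batch))
    #   tgt = torch.randint(0, V, (seq_len, batch))
    #   lengths = torch.randint(1, seq_len + 1, (turns, batch))
    #   outputs, kl_div, bow_loss = model(src, tgt, lengths)
    #   # outputs: [seq_len, batch, V]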