forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModels.py
667 lines (568 loc) · 23.5 KB
/
Models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack
import onmt
from onmt.Utils import aeq
def rnn_factory(rnn_type, **kwargs):
    """Instantiate a recurrent layer of the requested type.

    Args:
        rnn_type (str): ``"SRU"`` selects :obj:`onmt.modules.SRU`; any other
            value is looked up by name on :mod:`torch.nn` (e.g. ``"LSTM"``).
        **kwargs: forwarded unchanged to the layer constructor.

    Returns:
        (module, bool): the layer, and a flag that is ``True`` only for SRU,
        meaning the layer cannot consume a ``PackedSequence``.
    """
    if rnn_type != "SRU":
        # Use the pytorch implementation whenever one is available.
        return getattr(nn, rnn_type)(**kwargs), False
    # SRU doesn't support PackedSequence.
    return onmt.modules.SRU(**kwargs), True
class EncoderBase(nn.Module):
    """
    Base encoder class. Specifies the interface used by different encoder types
    and required by :obj:`onmt.Models.NMTModel`.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
            C[Pos 1]
            D[Pos 2]
            E[Pos N]
          end
          F[Memory_Bank]
          G[Final]
          A-->C
          A-->D
          A-->E
          C-->F
          D-->F
          E-->F
          E-->G
    """
    def _check_args(self, input, lengths=None, hidden=None):
        """Sanity-check the arguments handed to an encoder.

        ``input.size()`` is unpacked into exactly three values, so a
        non-3-D ``input`` raises ``ValueError``. When ``lengths`` is given
        it must be 1-D and match the second (batch) dimension of ``input``.
        ``hidden`` is accepted for interface symmetry but not inspected.
        """
        _, n_batch, _ = input.size()
        if lengths is None:
            return
        n_batch_, = lengths.size()
        aeq(n_batch, n_batch_)

    def forward(self, src, lengths=None, encoder_state=None):
        """
        Args:
            src (:obj:`LongTensor`):
               padded sequences of sparse indices `[src_len x batch x nfeat]`
            lengths (:obj:`LongTensor`): length of each sequence `[batch]`
            encoder_state (rnn-class specific):
               initial encoder_state state.

        Returns:
            (tuple of :obj:`FloatTensor`, :obj:`FloatTensor`):
                * final encoder state, used to initialize decoder
                * memory bank for attention, `[src_len x batch x hidden]`
        """
        raise NotImplementedError
class MeanEncoder(EncoderBase):
    """A trivial non-recurrent encoder. Simply applies mean pooling.

    When ``lengths`` is given, the time average only counts the first
    ``lengths[b]`` positions of each batch column ``b`` so padding does not
    dilute the pooled vector; without ``lengths`` every position is averaged.

    Args:
       num_layers (int): number of replicated layers
       embeddings (:obj:`onmt.modules.Embeddings`): embedding module to use
    """
    def __init__(self, num_layers, embeddings):
        super(MeanEncoder, self).__init__()
        self.num_layers = num_layers
        self.embeddings = embeddings

    def forward(self, src, lengths=None, encoder_state=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths, encoder_state)

        emb = self.embeddings(src)
        _, batch, emb_dim = emb.size()
        if lengths is None:
            mean = emb.mean(0)
        else:
            mean = self._masked_mean(emb, lengths)
        # The same pooled vector is replicated for every layer and used for
        # both members of the returned (h, c)-style final state.
        mean = mean.expand(self.num_layers, batch, emb_dim)
        memory_bank = emb
        encoder_final = (mean, mean)
        return encoder_final, memory_bank

    @staticmethod
    def _masked_mean(emb, lengths):
        """Average ``emb`` (``[src_len x batch x dim]``) over time, using
        only the first ``lengths[b]`` steps of batch column ``b``."""
        # Callers may hand over a plain tensor or a Variable.
        len_data = lengths.data if isinstance(lengths, Variable) else lengths
        len_data = len_data.view(-1)
        steps = torch.arange(0, emb.size(0)).type_as(len_data)
        # mask[t, b] == 1 for real tokens (t < lengths[b]), 0 for padding;
        # type_as(emb.data) also moves it to emb's device.
        mask = steps.unsqueeze(1).lt(len_data.unsqueeze(0)).type_as(emb.data)
        mask = Variable(mask.unsqueeze(2), requires_grad=False)
        # clamp(min=1) avoids dividing by zero for an empty sequence.
        counts = len_data.clamp(min=1).type_as(emb.data).unsqueeze(1)
        counts = Variable(counts, requires_grad=False)
        return (emb * mask).sum(0) / counts
class RNNEncoder(EncoderBase):
    """ A generic recurrent neural network encoder.

    Args:
       rnn_type (:obj:`str`):
          style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
       bidirectional (bool) : use a bidirectional RNN
       num_layers (int) : number of stacked layers
       hidden_size (int) : hidden size of each layer
       dropout (float) : dropout value for :obj:`nn.Dropout`
       embeddings (:obj:`onmt.modules.Embeddings`): embedding module to use
       use_bridge (bool): pass the final encoder state through a learned
          linear + ReLU "bridge" before returning it
    """
    def __init__(self, rnn_type, bidirectional, num_layers,
                 hidden_size, dropout=0.0, embeddings=None,
                 use_bridge=False):
        super(RNNEncoder, self).__init__()
        assert embeddings is not None

        num_directions = 2 if bidirectional else 1
        assert hidden_size % num_directions == 0
        # Split the requested size across directions.
        hidden_size = hidden_size // num_directions
        self.embeddings = embeddings

        self.rnn, self.no_pack_padded_seq = \
            rnn_factory(rnn_type,
                        input_size=embeddings.embedding_size,
                        hidden_size=hidden_size,
                        num_layers=num_layers,
                        dropout=dropout,
                        bidirectional=bidirectional)

        # Initialize the bridge layer
        self.use_bridge = use_bridge
        if self.use_bridge:
            self._initialize_bridge(rnn_type,
                                    hidden_size,
                                    num_layers)

    def forward(self, src, lengths=None, encoder_state=None):
        "See :obj:`EncoderBase.forward()`"
        self._check_args(src, lengths, encoder_state)

        emb = self.embeddings(src)

        # SRU cannot consume a PackedSequence, so padding is only skipped
        # for the torch.nn recurrent layers.
        use_packing = lengths is not None and not self.no_pack_padded_seq
        packed_emb = emb
        if use_packing:
            # pack() takes the lengths as a plain Python list.
            # NOTE(review): older pack_padded_sequence requires lengths in
            # decreasing order; presumably the batch iterator sorts them.
            packed_emb = pack(emb, lengths.view(-1).tolist())

        memory_bank, encoder_final = self.rnn(packed_emb, encoder_state)

        if use_packing:
            memory_bank = unpack(memory_bank)[0]

        if self.use_bridge:
            encoder_final = self._bridge(encoder_final)
        return encoder_final, memory_bank

    def _initialize_bridge(self, rnn_type,
                           hidden_size,
                           num_layers):
        """Create one Linear(hidden_size * num_layers, same) per state."""
        # LSTM has hidden and cell state, other only one
        number_of_states = 2 if rnn_type == "LSTM" else 1
        # Total number of states
        self.total_hidden_dim = hidden_size * num_layers

        # Build a linear layer for each
        self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim,
                                               self.total_hidden_dim,
                                               bias=True)
                                     for _ in range(number_of_states)])

    def _bridge(self, hidden):
        """
        Forward hidden state through bridge
        """
        def bottle_hidden(linear, states):
            """
            Apply `linear` + ReLU to `states` (`[rows x batch x dim]`) and
            return a tensor of the original size.

            The tensor is made batch-first before being cut into
            `total_hidden_dim`-sized rows, so every row fed to the linear
            layer comes from a single batch example (a plain
            `view(-1, total_hidden_dim)` on the batch-second layout mixes
            examples whenever `total_hidden_dim` exceeds `dim`).
            """
            batch_first = states.transpose(0, 1).contiguous()
            result = linear(batch_first.view(-1, self.total_hidden_dim))
            result = F.relu(result).view(batch_first.size())
            return result.transpose(0, 1).contiguous()

        if isinstance(hidden, tuple):  # LSTM
            outs = tuple([bottle_hidden(layer, hidden[ix])
                          for ix, layer in enumerate(self.bridge)])
        else:
            outs = bottle_hidden(self.bridge[0], hidden)
        return outs
class RNNDecoderBase(nn.Module):
    """
    Base recurrent attention-based decoder class.
    Specifies the interface used by different decoder types
    and required by :obj:`onmt.Models.NMTModel`.

    Subclasses provide `_build_rnn`, the `_input_size` property and
    `_run_forward_pass`.

    .. mermaid::

       graph BT
          A[Input]
          subgraph RNN
             C[Pos 1]
             D[Pos 2]
             E[Pos N]
          end
          G[Decoder State]
          H[Decoder State]
          I[Outputs]
          F[Memory_Bank]
          A--emb-->C
          A--emb-->D
          A--emb-->E
          H-->C
          C-- attn --- F
          D-- attn --- F
          E-- attn --- F
          C-->I
          D-->I
          E-->I
          E-->G
          F---I

    Args:
       rnn_type (:obj:`str`):
          style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
       bidirectional_encoder (bool) : use with a bidirectional encoder
       num_layers (int) : number of stacked layers
       hidden_size (int) : hidden size of each layer
       attn_type (str) : see :obj:`onmt.modules.GlobalAttention`
       coverage_attn (bool): see :obj:`onmt.modules.GlobalAttention`
       context_gate (str): see :obj:`onmt.modules.ContextGate`
       copy_attn (bool): setup a separate copy attention mechanism
       dropout (float) : dropout value for :obj:`nn.Dropout`
       embeddings (:obj:`onmt.modules.Embeddings`): embedding module to use
       reuse_copy_attn (bool): together with `copy_attn`, reuse the standard
          attention as the copy attention instead of building a separate
          copy attention layer
    """
    def __init__(self, rnn_type, bidirectional_encoder, num_layers,
                 hidden_size, attn_type="general",
                 coverage_attn=False, context_gate=None,
                 copy_attn=False, dropout=0.0, embeddings=None,
                 reuse_copy_attn=False):
        super(RNNDecoderBase, self).__init__()

        # Basic attributes.
        self.decoder_type = 'rnn'
        self.bidirectional_encoder = bidirectional_encoder
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embeddings = embeddings
        self.dropout = nn.Dropout(dropout)

        # Build the RNN. `_input_size` reads self.embeddings, so it must be
        # assigned above.
        self.rnn = self._build_rnn(rnn_type,
                                   input_size=self._input_size,
                                   hidden_size=hidden_size,
                                   num_layers=num_layers,
                                   dropout=dropout)

        # Set up the context gate.
        self.context_gate = None
        if context_gate is not None:
            self.context_gate = onmt.modules.context_gate_factory(
                context_gate, self._input_size,
                hidden_size, hidden_size, hidden_size
            )

        # Set up the standard attention.
        self._coverage = coverage_attn
        self.attn = onmt.modules.GlobalAttention(
            hidden_size, coverage=coverage_attn,
            attn_type=attn_type
        )

        # Set up a separated copy attention layer, if needed.
        if copy_attn and not reuse_copy_attn:
            self.copy_attn = onmt.modules.GlobalAttention(
                hidden_size, attn_type=attn_type
            )
        self._copy = bool(copy_attn)
        self._reuse_copy_attn = reuse_copy_attn

    def forward(self, tgt, memory_bank, state, memory_lengths=None):
        """
        Args:
            tgt (`LongTensor`): sequences of padded tokens
                                `[tgt_len x batch x nfeats]`.
            memory_bank (`FloatTensor`): vectors from the encoder
                 `[src_len x batch x hidden]`.
            state (:obj:`onmt.Models.DecoderState`):
                 decoder state object to initialize the decoder
            memory_lengths (`LongTensor`): the padded source lengths
                `[batch]`.
        Returns:
            (`FloatTensor`,:obj:`onmt.Models.DecoderState`,`FloatTensor`):
                * decoder_outputs: output from the decoder (after attn)
                         `[tgt_len x batch x hidden]`.
                * decoder_state: final hidden state from the decoder
                * attns: distribution over src at each tgt
                        `[tgt_len x batch x src_len]`.
        """
        # Check
        assert isinstance(state, RNNDecoderState)
        _, tgt_batch, _ = tgt.size()
        _, memory_batch, _ = memory_bank.size()
        aeq(tgt_batch, memory_batch)
        # END

        # Run the forward pass of the RNN.
        decoder_final, decoder_outputs, attns = self._run_forward_pass(
            tgt, memory_bank, state, memory_lengths=memory_lengths)

        # Update the state with the result.
        final_output = decoder_outputs[-1]
        coverage = None
        if "coverage" in attns:
            coverage = attns["coverage"][-1].unsqueeze(0)
        state.update_state(decoder_final, final_output.unsqueeze(0), coverage)

        # Concatenate per-step lists along a new first dimension. Decoders
        # that process the whole sequence at once (StdRNNDecoder) already
        # return tensors, and newer PyTorch rejects a single tensor passed to
        # torch.stack, so only lists are stacked.
        if isinstance(decoder_outputs, list):
            decoder_outputs = torch.stack(decoder_outputs)
        for k in attns:
            if isinstance(attns[k], list):
                attns[k] = torch.stack(attns[k])

        return decoder_outputs, state, attns

    def init_decoder_state(self, src, memory_bank, encoder_final):
        """Wrap the encoder's final state in an :obj:`RNNDecoderState`."""
        def _fix_enc_hidden(h):
            # The encoder hidden is  (layers*directions) x batch x dim.
            # We need to convert it to layers x batch x (directions*dim).
            if self.bidirectional_encoder:
                h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
            return h

        if isinstance(encoder_final, tuple):  # LSTM
            return RNNDecoderState(self.hidden_size,
                                   tuple([_fix_enc_hidden(enc_hid)
                                         for enc_hid in encoder_final]))
        else:  # GRU
            return RNNDecoderState(self.hidden_size,
                                   _fix_enc_hidden(encoder_final))
class StdRNNDecoder(RNNDecoderBase):
    """
    Standard fully batched RNN decoder with attention.
    Faster implementation, uses CuDNN for implementation.
    See :obj:`RNNDecoderBase` for options.


    Based around the approach from
    "Neural Machine Translation By Jointly Learning To Align and Translate"
    :cite:`Bahdanau2015`


    Implemented without input_feeding and currently with no `coverage_attn`
    or `copy_attn` support.
    """
    def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
        """
        Private helper for running the specific RNN forward pass.
        Must be overridden by all subclasses.

        Args:
            tgt (LongTensor): a sequence of input tokens tensors
                                 [len x batch x nfeats].
            memory_bank (FloatTensor): output(tensor sequence) from the encoder
                        RNN of size (src_len x batch x hidden_size).
            state (:obj:`RNNDecoderState`): decoder state whose `hidden`
                        initializes the RNN.
            memory_lengths (LongTensor): the source memory_bank lengths.
        Returns:
            decoder_final (Variable): final hidden state from the decoder.
            decoder_outputs (FloatTensor): attention outputs for all time
                        steps, returned as one tensor (not a per-step list).
            attns (dict of (str, FloatTensor)): attention tensor for all
                        time steps, keyed by attention type ("std").
        """
        assert not self._copy  # TODO, no support yet.
        assert not self._coverage  # TODO, no support yet.

        # Initialize local and return variables.
        attns = {}
        emb = self.embeddings(tgt)

        # Run the forward pass of the RNN. An LSTM consumes the (h, c) tuple;
        # single-state RNNs (GRU, vanilla RNN, SRU) consume the bare tensor
        # that RNNDecoderState keeps wrapped in a 1-tuple.
        hidden = state.hidden
        if len(hidden) == 1:
            hidden = hidden[0]
        rnn_output, decoder_final = self.rnn(emb, hidden)

        # Check
        tgt_len, tgt_batch, _ = tgt.size()
        output_len, output_batch, _ = rnn_output.size()
        aeq(tgt_len, output_len)
        aeq(tgt_batch, output_batch)
        # END

        # Calculate the attention.
        decoder_outputs, p_attn = self.attn(
            rnn_output.transpose(0, 1).contiguous(),
            memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths
        )
        attns["std"] = p_attn

        # Calculate the context gate (applied on flattened time x batch rows).
        if self.context_gate is not None:
            decoder_outputs = self.context_gate(
                emb.view(-1, emb.size(2)),
                rnn_output.view(-1, rnn_output.size(2)),
                decoder_outputs.view(-1, decoder_outputs.size(2))
            )
            decoder_outputs = \
                decoder_outputs.view(tgt_len, tgt_batch, self.hidden_size)

        decoder_outputs = self.dropout(decoder_outputs)
        return decoder_final, decoder_outputs, attns

    def _build_rnn(self, rnn_type, **kwargs):
        rnn, _ = rnn_factory(rnn_type, **kwargs)
        return rnn

    @property
    def _input_size(self):
        """
        Private helper returning the number of expected features.
        """
        return self.embeddings.embedding_size
class InputFeedRNNDecoder(RNNDecoderBase):
    """
    Input feeding based decoder. See :obj:`RNNDecoderBase` for options.

    Based around the input feeding approach from
    "Effective Approaches to Attention-based Neural Machine Translation"
    :cite:`Luong2015`


    .. mermaid::

       graph BT
          A[Input n-1]
          AB[Input n]
          subgraph RNN
            E[Pos n-1]
            F[Pos n]
            E --> F
          end
          G[Encoder]
          H[Memory_Bank n-1]
          A --> E
          AB --> F
          E --> H
          G --> H
    """

    def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None):
        """
        See StdRNNDecoder._run_forward_pass() for description
        of arguments and return values. Here `decoder_outputs` and the
        `attns` values are per-step lists.
        """
        # Additional args check.
        input_feed = state.input_feed.squeeze(0)
        input_feed_batch, _ = input_feed.size()
        _, tgt_batch, _ = tgt.size()
        aeq(tgt_batch, input_feed_batch)
        # END Additional args check.

        # Initialize local and return variables.
        decoder_outputs = []
        attns = {"std": []}
        if self._copy:
            attns["copy"] = []
        if self._coverage:
            attns["coverage"] = []

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # len x batch x embedding_dim

        hidden = state.hidden
        coverage = state.coverage.squeeze(0) \
            if state.coverage is not None else None

        # The attention layers take the memory bank batch-first; the
        # transposed view is loop-invariant.
        memory_bank_t = memory_bank.transpose(0, 1)

        # Input feed concatenates hidden state with
        # input at every time step.
        for emb_t in emb.split(1):
            emb_t = emb_t.squeeze(0)
            decoder_input = torch.cat([emb_t, input_feed], 1)

            rnn_output, hidden = self.rnn(decoder_input, hidden)
            decoder_output, p_attn = self.attn(
                rnn_output,
                memory_bank_t,
                memory_lengths=memory_lengths)
            if self.context_gate is not None:
                # TODO: context gate should be employed
                # instead of second RNN transform.
                decoder_output = self.context_gate(
                    decoder_input, rnn_output, decoder_output
                )
            decoder_output = self.dropout(decoder_output)
            input_feed = decoder_output

            decoder_outputs += [decoder_output]
            attns["std"] += [p_attn]

            # Update the coverage attention.
            # NOTE(review): the running coverage is only recorded in attns;
            # it is never passed to self.attn, so it does not affect the
            # attention scores here. Confirm this is intended.
            if self._coverage:
                coverage = coverage + p_attn \
                    if coverage is not None else p_attn
                attns["coverage"] += [coverage]

            # Run the forward pass of the copy attention layer, masking the
            # padded source positions just like the standard attention.
            if self._copy and not self._reuse_copy_attn:
                _, copy_attn = self.copy_attn(decoder_output,
                                              memory_bank_t,
                                              memory_lengths=memory_lengths)
                attns["copy"] += [copy_attn]

        if self._copy and self._reuse_copy_attn:
            # Reuse the standard attention distributions as copy attention.
            attns["copy"] = attns["std"]

        # Return result.
        return hidden, decoder_outputs, attns

    def _build_rnn(self, rnn_type, input_size,
                   hidden_size, num_layers, dropout):
        """Build the stacked cell; every non-LSTM type (including "RNN")
        gets :obj:`onmt.modules.StackedGRU`."""
        assert not rnn_type == "SRU", "SRU doesn't support input feed! " \
            "Please set -input_feed 0!"
        if rnn_type == "LSTM":
            stacked_cell = onmt.modules.StackedLSTM
        else:
            stacked_cell = onmt.modules.StackedGRU
        return stacked_cell(num_layers, input_size,
                            hidden_size, dropout)

    @property
    def _input_size(self):
        """
        Using input feed by concatenating input with attention vectors.
        """
        return self.embeddings.embedding_size + self.hidden_size
class NMTModel(nn.Module):
    """
    Core trainable object in OpenNMT. Implements a trainable interface
    for a simple, generic encoder + decoder model.

    Args:
      encoder (:obj:`EncoderBase`): an encoder object
      decoder (:obj:`RNNDecoderBase`): a decoder object
      multigpu (bool): setup for multigpu support
    """
    def __init__(self, encoder, decoder, multigpu=False):
        # nn.Module.__init__ must run before any attribute is assigned.
        super(NMTModel, self).__init__()
        self.multigpu = multigpu
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, lengths, dec_state=None):
        """Forward propagate a `src` and `tgt` pair for training.
        Possible initialized with a beginning decoder state.

        Args:
            src (:obj:`Tensor`):
                a source sequence passed to encoder.
                typically for inputs this will be a padded :obj:`LongTensor`
                of size `[len x batch x features]`. however, may be an
                image or other generic input depending on encoder.
            tgt (:obj:`LongTensor`):
                 a target sequence of size `[tgt_len x batch]`.
            lengths(:obj:`LongTensor`): the src lengths, pre-padding `[batch]`.
            dec_state (:obj:`DecoderState`, optional): initial decoder state
        Returns:
            (:obj:`FloatTensor`, `dict`, :obj:`onmt.Models.DecoderState`):

                 * decoder output `[tgt_len x batch x hidden]`
                 * dictionary attention dists of `[tgt_len x batch x src_len]`
                 * final decoder state

            With `multigpu`, the attention dict and decoder state are
            returned as None.
        """
        tgt = tgt[:-1]  # exclude last target from inputs

        enc_final, memory_bank = self.encoder(src, lengths)
        enc_state = \
            self.decoder.init_decoder_state(src, memory_bank, enc_final)
        decoder_outputs, dec_state, attns = \
            self.decoder(tgt, memory_bank,
                         enc_state if dec_state is None
                         else dec_state,
                         memory_lengths=lengths)
        if self.multigpu:
            # Not yet supported on multi-gpu
            dec_state = None
            attns = None
        return decoder_outputs, attns, dec_state
class DecoderState(object):
    """Interface for grouping together the current state of a recurrent
    decoder. In the simplest case just represents the hidden state of
    the model.  But can also be used for implementing various forms of
    input_feeding and non-recurrent models.

    Modules need to implement this to utilize beam search decoding.
    Subclasses expose their state tensors through an `_all` iterable;
    entries may be None and are then skipped.
    """
    def detach(self):
        """Detach every non-None state tensor from the graph, in place."""
        for h in self._all:
            if h is not None:
                h.detach_()

    def beam_update(self, idx, positions, beam_size):
        """Reorder, in place, the beams of sentence `idx`.

        Every state tensor's second dimension is viewed as
        `(beam_size, dim1 // beam_size)`; for column `idx` of that split the
        beam slots are overwritten with the beams listed in `positions`
        (a LongTensor of beam indices). Tensors of any rank >= 2 are
        supported; None entries are skipped, as in `detach`.
        """
        for e in self._all:
            if e is None:
                continue
            sizes = e.size()
            br = sizes[1]
            sent_states = e.view(sizes[0], beam_size, br // beam_size,
                                 *sizes[2:])[:, :, idx]
            # index_select copies first, so the in-place copy_ is safe.
            sent_states.data.copy_(
                sent_states.data.index_select(1, positions))
class RNNDecoderState(DecoderState):
    """Decoder state for the RNN decoders: hidden state(s), input feed and
    coverage.

    NOTE(review): `coverage` is not part of `_all`, so beam reordering,
    detaching and beam repetition leave it untouched. Confirm that is
    acceptable when coverage attention is used at translation time.
    """
    def __init__(self, hidden_size, rnnstate):
        """
        Args:
            hidden_size (int): the size of hidden layer of the decoder.
            rnnstate: final hidden state from the encoder.
                transformed to shape: layers x batch x (directions*dim).
        """
        self.hidden = self._as_tuple(rnnstate)
        self.coverage = None

        # Init the input feed: zeros of shape 1 x batch x hidden_size, built
        # with the same tensor type as the first hidden state.
        first_hidden = self.hidden[0]
        zeros = first_hidden.data.new(first_hidden.size(1),
                                      hidden_size).zero_()
        self.input_feed = Variable(zeros, requires_grad=False).unsqueeze(0)

    @staticmethod
    def _as_tuple(rnnstate):
        # Tuples (e.g. LSTM's (h, c)) are kept; a single tensor is wrapped
        # so that `hidden` is always a tuple.
        return rnnstate if isinstance(rnnstate, tuple) else (rnnstate,)

    @property
    def _all(self):
        return self.hidden + (self.input_feed,)

    def update_state(self, rnnstate, input_feed, coverage):
        """Replace hidden state(s), input feed and coverage."""
        self.hidden = self._as_tuple(rnnstate)
        self.input_feed = input_feed
        self.coverage = coverage

    def repeat_beam_size_times(self, beam_size):
        """ Repeat beam_size times along batch dimension. """
        repeated = [Variable(tensor.data.repeat(1, beam_size, 1),
                             volatile=True)
                    for tensor in self._all]
        # `_all` is hidden states followed by the input feed.
        self.hidden = tuple(repeated[:-1])
        self.input_feed = repeated[-1]