
Commit 9429425

Update seq2seq.py
1 parent 887ed38 commit 9429425

File tree

1 file changed (+1, -149 lines)

src/11_seq2seq/modules/seq2seq.py

Lines changed: 1 addition & 149 deletions
@@ -3,8 +3,7 @@
 from torch.nn.utils.rnn import pack_padded_sequence as pack
 from torch.nn.utils.rnn import pad_packed_sequence as unpack
 
-import simple_nmt.data_loader as data_loader
-from simple_nmt.search import SingleBeamSearchBoard
+import modules.data_loader as data_loader
 
 
 class Attention(nn.Module):
@@ -370,150 +369,3 @@ def search(self, src, is_greedy=True, max_length=255):
         # |indice| = (batch_size, length)
 
         return y_hats, indice
-
-    #@profile
-    def batch_beam_search(
-        self,
-        src,
-        beam_size=5,
-        max_length=255,
-        n_best=1,
-        length_penalty=.2
-    ):
-        mask, x_length = None, None
-
-        if isinstance(src, tuple):
-            x, x_length = src
-            mask = self.generate_mask(x, x_length)
-            # |mask| = (batch_size, length)
-        else:
-            x = src
-        batch_size = x.size(0)
-
-        emb_src = self.emb_src(x)
-        h_src, h_0_tgt = self.encoder((emb_src, x_length))
-        # |h_src| = (batch_size, length, hidden_size)
-        h_0_tgt = self.fast_merge_encoder_hiddens(h_0_tgt)
-
-        # Initialize one 'SingleBeamSearchBoard' per sample in the mini-batch.
-        boards = [SingleBeamSearchBoard(
-            h_src.device,
-            {
-                'hidden_state': {
-                    'init_status': h_0_tgt[0][:, i, :].unsqueeze(1),
-                    'batch_dim_index': 1,
-                }, # |hidden_state| = (n_layers, batch_size, hidden_size)
-                'cell_state': {
-                    'init_status': h_0_tgt[1][:, i, :].unsqueeze(1),
-                    'batch_dim_index': 1,
-                }, # |cell_state| = (n_layers, batch_size, hidden_size)
-                'h_t_1_tilde': {
-                    'init_status': None,
-                    'batch_dim_index': 0,
-                }, # |h_t_1_tilde| = (batch_size, 1, hidden_size)
-            },
-            beam_size=beam_size,
-            max_length=max_length,
-        ) for i in range(batch_size)]
-        is_done = [board.is_done() for board in boards]
-
-        length = 0
-        # Run the loop while the sum of 'is_done' is smaller than batch_size
-        # and the length is still smaller than max_length.
-        while sum(is_done) < batch_size and length <= max_length:
-            # current_batch_size = (batch_size - sum(is_done)) * beam_size
-
-            # Initialize fabricated variables.
-            # While batch-beam-search is running, the temporary
-            # batch-size of the fabricated mini-batch is
-            # 'beam_size'-times bigger than the original batch_size.
-            fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
-            fab_h_src, fab_mask = [], []
-
-            # Build the fabricated mini-batch in a non-parallel way.
-            # This may cause a bottleneck.
-            for i, board in enumerate(boards):
-                # Batchify only if inference for this sample is not finished yet.
-                if board.is_done() == 0:
-                    y_hat_i, prev_status = board.get_batch()
-                    hidden_i = prev_status['hidden_state']
-                    cell_i = prev_status['cell_state']
-                    h_t_tilde_i = prev_status['h_t_1_tilde']
-
-                    fab_input += [y_hat_i]
-                    fab_hidden += [hidden_i]
-                    fab_cell += [cell_i]
-                    fab_h_src += [h_src[i, :, :]] * beam_size
-                    fab_mask += [mask[i, :]] * beam_size
-                    if h_t_tilde_i is not None:
-                        fab_h_t_tilde += [h_t_tilde_i]
-                    else:
-                        fab_h_t_tilde = None
-
-            # Now, concatenate the lists of tensors.
-            fab_input = torch.cat(fab_input, dim=0)
-            fab_hidden = torch.cat(fab_hidden, dim=1)
-            fab_cell = torch.cat(fab_cell, dim=1)
-            fab_h_src = torch.stack(fab_h_src)
-            fab_mask = torch.stack(fab_mask)
-            if fab_h_t_tilde is not None:
-                fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
-            # |fab_input|     = (current_batch_size, 1)
-            # |fab_hidden|    = (n_layers, current_batch_size, hidden_size)
-            # |fab_cell|      = (n_layers, current_batch_size, hidden_size)
-            # |fab_h_src|     = (current_batch_size, length, hidden_size)
-            # |fab_mask|      = (current_batch_size, length)
-            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
-
-            emb_t = self.emb_dec(fab_input)
-            # |emb_t| = (current_batch_size, 1, word_vec_size)
-
-            fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(emb_t,
-                                                                      fab_h_t_tilde,
-                                                                      (fab_hidden, fab_cell))
-            # |fab_decoder_output| = (current_batch_size, 1, hidden_size)
-            context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
-            # |context_vector| = (current_batch_size, 1, hidden_size)
-            fab_h_t_tilde = self.tanh(self.concat(torch.cat([fab_decoder_output,
-                                                             context_vector
-                                                             ], dim=-1)))
-            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
-            y_hat = self.generator(fab_h_t_tilde)
-            # |y_hat| = (current_batch_size, 1, output_size)
-
-            # Separate the results for each sample.
-            # fab_hidden[:, begin:end, :] = (n_layers, beam_size, hidden_size)
-            # fab_cell[:, begin:end, :]   = (n_layers, beam_size, hidden_size)
-            # fab_h_t_tilde[begin:end]    = (beam_size, 1, hidden_size)
-            cnt = 0
-            for board in boards:
-                if board.is_done() == 0:
-                    # Decide the range of each sample.
-                    begin = cnt * beam_size
-                    end = begin + beam_size
-
-                    # Pick the k-best results for each sample.
-                    board.collect_result(
-                        y_hat[begin:end],
-                        {
-                            'hidden_state': fab_hidden[:, begin:end, :],
-                            'cell_state'  : fab_cell[:, begin:end, :],
-                            'h_t_1_tilde' : fab_h_t_tilde[begin:end],
-                        },
-                    )
-                    cnt += 1
-
-            is_done = [board.is_done() for board in boards]
-            length += 1
-
-        # Pick the n-best hypotheses.
-        batch_sentences, batch_probs = [], []
-
-        # Collect the results.
-        for i, board in enumerate(boards):
-            sentences, probs = board.get_n_best(n_best, length_penalty=length_penalty)
-
-            batch_sentences += [sentences]
-            batch_probs += [probs]
-
-        return batch_sentences, batch_probs
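
Note: this commit deletes 'batch_beam_search' (and the 'SingleBeamSearchBoard' import it depended on) and retargets the 'data_loader' import to the local 'modules' package, leaving 'search' as the module's only decoding entry point. A minimal inference sketch after this change, assuming 'model' is an instance of the sequence-to-sequence class defined in this file (the input tensors below are hypothetical placeholders):

    import torch

    # Fake a padded source mini-batch of 2 sentences (hypothetical data);
    # lengths are sorted in descending order for pack_padded_sequence.
    x = torch.randint(0, 100, (2, 10))   # |x| = (batch_size, length)
    x_length = torch.tensor([10, 7])     # true lengths before padding

    model.eval()
    with torch.no_grad():
        # Greedy decoding via the remaining 'search' method;
        # |indice| = (batch_size, length), per the shape comment above.
        y_hats, indice = model.search((x, x_length), is_greedy=True, max_length=255)

Callers that previously used 'model.batch_beam_search(src, beam_size=5, n_best=1)' would have to switch to 'search' (or restore the beam-search code) after this change.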
