@@ -3,8 +3,7 @@
 from torch.nn.utils.rnn import pack_padded_sequence as pack
 from torch.nn.utils.rnn import pad_packed_sequence as unpack

-import simple_nmt.data_loader as data_loader
-from simple_nmt.search import SingleBeamSearchBoard
+import modules.data_loader as data_loader


 class Attention(nn.Module):
@@ -370,150 +369,3 @@ def search(self, src, is_greedy=True, max_length=255): |
         # |indice| = (batch_size, length)

         return y_hats, indice
-
-    #@profile
-    def batch_beam_search(
-        self,
-        src,
-        beam_size=5,
-        max_length=255,
-        n_best=1,
-        length_penalty=.2
-    ):
-        mask, x_length = None, None
-
-        if isinstance(src, tuple):
-            x, x_length = src
-            mask = self.generate_mask(x, x_length)
-            # |mask| = (batch_size, length)
-        else:
-            x = src
-        batch_size = x.size(0)
-
-        emb_src = self.emb_src(x)
-        h_src, h_0_tgt = self.encoder((emb_src, x_length))
-        # |h_src| = (batch_size, length, hidden_size)
-        h_0_tgt = self.fast_merge_encoder_hiddens(h_0_tgt)
-
-        # Initialize a 'SingleBeamSearchBoard' for each sample in the mini-batch.
-        boards = [SingleBeamSearchBoard(
-            h_src.device,
-            {
-                'hidden_state': {
-                    'init_status': h_0_tgt[0][:, i, :].unsqueeze(1),
-                    'batch_dim_index': 1,
-                }, # |hidden_state| = (n_layers, batch_size, hidden_size)
-                'cell_state': {
-                    'init_status': h_0_tgt[1][:, i, :].unsqueeze(1),
-                    'batch_dim_index': 1,
-                }, # |cell_state| = (n_layers, batch_size, hidden_size)
-                'h_t_1_tilde': {
-                    'init_status': None,
-                    'batch_dim_index': 0,
-                }, # |h_t_1_tilde| = (batch_size, 1, hidden_size)
-            },
-            beam_size=beam_size,
-            max_length=max_length,
-        ) for i in range(batch_size)]
-        is_done = [board.is_done() for board in boards]
-
-        length = 0
-        # Run the loop while the number of finished boards is smaller than
-        # batch_size and length has not exceeded max_length.
-        while sum(is_done) < batch_size and length <= max_length:
-            # current_batch_size = (batch_size - sum(is_done)) * beam_size
-
-            # Initialize the fabricated variables.
-            # While batch-beam-search is running, the temporary batch-size
-            # of the fabricated mini-batch is 'beam_size'-times bigger than
-            # the original batch_size.
-            fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
-            fab_h_src, fab_mask = [], []
-
-            # Build the fabricated mini-batch in a non-parallel way.
-            # This may cause a bottleneck.
-            for i, board in enumerate(boards):
-                # Batchify only if inference for this sample is not finished yet.
-                if board.is_done() == 0:
-                    y_hat_i, prev_status = board.get_batch()
-                    hidden_i = prev_status['hidden_state']
-                    cell_i = prev_status['cell_state']
-                    h_t_tilde_i = prev_status['h_t_1_tilde']
-
-                    fab_input += [y_hat_i]
-                    fab_hidden += [hidden_i]
-                    fab_cell += [cell_i]
-                    fab_h_src += [h_src[i, :, :]] * beam_size
-                    fab_mask += [mask[i, :]] * beam_size
-                    if h_t_tilde_i is not None:
-                        fab_h_t_tilde += [h_t_tilde_i]
-                    else:
-                        fab_h_t_tilde = None
-
-            # Now, concatenate the lists of tensors.
-            fab_input = torch.cat(fab_input, dim=0)
-            fab_hidden = torch.cat(fab_hidden, dim=1)
-            fab_cell = torch.cat(fab_cell, dim=1)
-            fab_h_src = torch.stack(fab_h_src)
-            fab_mask = torch.stack(fab_mask)
-            if fab_h_t_tilde is not None:
-                fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
-            # |fab_input| = (current_batch_size, 1)
-            # |fab_hidden| = (n_layers, current_batch_size, hidden_size)
-            # |fab_cell| = (n_layers, current_batch_size, hidden_size)
-            # |fab_h_src| = (current_batch_size, length, hidden_size)
-            # |fab_mask| = (current_batch_size, length)
-            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
-
-            emb_t = self.emb_dec(fab_input)
-            # |emb_t| = (current_batch_size, 1, word_vec_size)
-
-            fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(emb_t,
-                                                                      fab_h_t_tilde,
-                                                                      (fab_hidden, fab_cell))
-            # |fab_decoder_output| = (current_batch_size, 1, hidden_size)
-            context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
-            # |context_vector| = (current_batch_size, 1, hidden_size)
-            fab_h_t_tilde = self.tanh(self.concat(torch.cat([fab_decoder_output,
-                                                             context_vector
-                                                             ], dim=-1)))
-            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
-            y_hat = self.generator(fab_h_t_tilde)
-            # |y_hat| = (current_batch_size, 1, output_size)
-
-            # Separate the results for each sample.
-            # fab_hidden[:, begin:end, :] = (n_layers, beam_size, hidden_size)
-            # fab_cell[:, begin:end, :] = (n_layers, beam_size, hidden_size)
-            # fab_h_t_tilde[begin:end] = (beam_size, 1, hidden_size)
-            cnt = 0
-            for board in boards:
-                if board.is_done() == 0:
-                    # Determine the range belonging to this sample.
-                    begin = cnt * beam_size
-                    end = begin + beam_size
-
-                    # Pick the k-best results for each sample.
-                    board.collect_result(
-                        y_hat[begin:end],
-                        {
-                            'hidden_state': fab_hidden[:, begin:end, :],
-                            'cell_state' : fab_cell[:, begin:end, :],
-                            'h_t_1_tilde' : fab_h_t_tilde[begin:end],
-                        },
-                    )
-                    cnt += 1
-
-            is_done = [board.is_done() for board in boards]
-            length += 1
-
-        # Pick the n-best hypotheses.
-        batch_sentences, batch_probs = [], []
-
-        # Collect the results.
-        for i, board in enumerate(boards):
-            sentences, probs = board.get_n_best(n_best, length_penalty=length_penalty)
-
-            batch_sentences += [sentences]
-            batch_probs += [probs]
-
-        return batch_sentences, batch_probs
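
Note: the comments in the removed method flag the per-sample Python loop that
builds the fabricated mini-batch as a potential bottleneck. For illustration
only, here is a minimal sketch of doing the same beam-wise expansion
tensor-wise instead; the function name is hypothetical and not part of this
repository:

import torch

def expand_for_beam(h_src, beam_size):
    # Repeat each sample's encoder states beam_size times, folding the
    # (batch, beam) pair into one fabricated batch dimension so that
    # sample i occupies rows [i * beam_size, (i + 1) * beam_size).
    # |h_src| = (batch_size, length, hidden_size)
    batch_size, length, hidden_size = h_src.size()
    fab_h_src = h_src.unsqueeze(1).expand(batch_size, beam_size, length, hidden_size)
    return fab_h_src.reshape(batch_size * beam_size, length, hidden_size)

This preserves the same begin/end row layout that collect_result() relies on
above, while avoiding per-sample list building; samples whose boards are
already done would still need to be dropped or masked, which is why the
removed code batchifies only unfinished samples.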