
Commit de856f2

Implement seq2seq BeamSearch

1 parent 7d6959b commit de856f2

3 files changed: +375, -409 lines changed


src/11_seq2seq/modules/search.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
from operator import itemgetter

import torch
import torch.nn as nn

import modules.data_loader as data_loader

LENGTH_PENALTY = .2
MIN_LENGTH = 5


class SingleBeamSearchBoard():

    def __init__(
        self,
        device,
        prev_status_config,
        beam_size=5,
        max_length=255,
    ):
        self.beam_size = beam_size
        self.max_length = max_length

        self.device = device
        # Word indices (i.e., the final predicted words) for each time-step * beam_size.
        # Initialized with <BOS>, since every beam starts from it.
        self.word_indice = [torch.LongTensor(beam_size).zero_().to(self.device) + data_loader.BOS]
        # Beam index each word was selected from, for each time-step.
        # Initialized with -1, since nothing has been selected yet.
        self.beam_indice = [torch.LongTensor(beam_size).zero_().to(self.device) - 1]
        # Cumulative log-probability of each beam.
        # Initialized as [0, -inf, -inf, ...].
        self.cumulative_probs = [torch.FloatTensor([.0] + [-float('inf')] * (beam_size - 1)).to(self.device)]
        # Whether each beam has reached EOS.
        # 1 if it is done else 0.
        self.masks = [torch.BoolTensor(beam_size).zero_().to(self.device)]
        # We don't need to remember every time-step of hidden states:
        # prev_hidden, prev_cell, prev_h_t_tilde.
        # What we need to remember is just the last one.

        '''
        Storage for the previous hidden, cell, and h_tilde of each beam.
        Only the last time-step needs to be kept.

        At the beginning, the hidden, cell, and h_tilde handed over from the
        encoder are simply expanded by beam_size. Since h_tilde is None at
        first, it is handled separately.
        '''
        self.prev_status = {}
        self.batch_dims = {}
        for prev_status_name, each_config in prev_status_config.items():
            init_status = each_config['init_status']
            batch_dim_index = each_config['batch_dim_index']
            if init_status is not None:
                self.prev_status[prev_status_name] = torch.cat([init_status] * beam_size,
                                                               dim=batch_dim_index)
            else:
                self.prev_status[prev_status_name] = None
            self.batch_dims[prev_status_name] = batch_dim_index

        self.current_time_step = 0
        self.done_cnt = 0

    def get_length_penalty(
        self,
        length,
        alpha=LENGTH_PENALTY,
        min_length=MIN_LENGTH,
    ):
        # Calculate the length-penalty,
        # because shorter sentences usually have higher probability.
        # In fact, we represent this as log-probability, which is a negative value.
        # Thus, we need to multiply a bigger penalty for shorter ones.
        p = ((min_length + 1) / (min_length + length))**alpha

        return p
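    # Illustrative numbers only (not from this commit): with min_length = 5
    # and alpha = 0.2, a hypothesis of length 2 gets p = (6 / 7) ** 0.2 ≈ 0.97,
    # while one of length 20 gets p = (6 / 25) ** 0.2 ≈ 0.75. Since the scores
    # being multiplied are negative log-probabilities, the smaller factor for
    # the longer hypothesis pulls its score closer to zero, offsetting the
    # natural advantage of short hypotheses.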
    def is_done(self):
        # Return 1, if we had EOS more than 'beam_size'-times.
        if self.done_cnt >= self.beam_size:
            return 1
        return 0

    def get_batch(self):
        '''
        Return the word indices of the last time-step in the current beams.
        At the very first step they are all BOS; afterwards they are the
        top-k words predicted at the previous step.
        '''
        y_hat = self.word_indice[-1].unsqueeze(-1)
        # |y_hat| = (beam_size, 1)
        # if model != transformer:
        #     |hidden| = |cell| = (n_layers, beam_size, hidden_size)
        #     |h_t_tilde| = (beam_size, 1, hidden_size) or None
        # else:
        #     |prev_state_i| = (beam_size, length, hidden_size),
        #     where i is an index of layer.
        return y_hat, self.prev_status

    #@profile
    def collect_result(self, y_hat, prev_status):
        '''
        y_hat: the words predicted by each beam at the current time-step
        prev_status: the hidden, cell, and h_tilde produced at the current
            time-step; since they were fed in beam-wise, they come back
            beam-wise as well.
        '''

        # |y_hat| = (beam_size, 1, output_size)
        # |hidden| = |cell| = (n_layers, beam_size, hidden_size)
        # |h_t_tilde| = (beam_size, 1, hidden_size)
        output_size = y_hat.size(-1)

        self.current_time_step += 1

        # Compute the cumulative log-probability.
        # (beam_size) --> (beam_size, 1, 1) --> (beam_size, 1, output_size)
        # For beams that are already finished (i.e., reached EOS),
        # overwrite the probability with -inf.
        cumulative_prob = self.cumulative_probs[-1].masked_fill_(self.masks[-1], -float('inf'))
        # To get a cumulative probability for every word (output_size),
        # expand to (beam_size, 1, output_size) and add the incoming y_hat,
        # which yields the final cumulative probabilities.
        # Note that the initial cumulative_prob is (0, -inf, -inf, ...),
        # so at the first step every result comes from the first beam only.
        cumulative_prob = y_hat + cumulative_prob.view(-1, 1, 1).expand(self.beam_size, 1, output_size)
        # |cumulative_prob| = (beam_size, 1, output_size)

        # Flatten cumulative_prob to (beam_size * output_size,),
        # then sort in descending order of probability.
        # top_indice keeps the indices from before sorting.
        top_log_prob, top_indice = cumulative_prob.view(-1).sort(descending=True)
        # Then cut it down to the top-k.
        top_log_prob, top_indice = top_log_prob[:self.beam_size], top_indice[:self.beam_size]
        # |top_log_prob| = (beam_size,)
        # |top_indice| = (beam_size,)
        # top_log_prob: probability of each selected word
        # top_indice: index of each selected word -> with this index we can
        #             trace which word came from which beam.

        # Taking each top_indice modulo output_size recovers the word index
        # it originally pointed to.
        self.word_indice += [top_indice.fmod(output_size)]
        # Dividing each top_indice by output_size recovers the beam index
        # it originally came from.
        # Together these identify which word was selected from which beam.
        self.beam_indice += [top_indice.div(float(output_size)).long()]

        # Store the cumulative probabilities computed at this step.
        self.cumulative_probs += [top_log_prob]
        # Mask the positions where EOS appeared in this step's result.
        self.masks += [torch.eq(self.word_indice[-1], data_loader.EOS)]
        # Update done_cnt based on the mask result.
        self.done_cnt += self.masks[-1].float().sum()

        # Store the hidden, cell, and h_tilde produced at the current
        # time-step; they will be used when get_batch is called next.
        # Only the entries belonging to the beam indices selected in the
        # top-k are kept.
        for prev_status_name, prev_status in prev_status.items():
            self.prev_status[prev_status_name] = torch.index_select(
                prev_status,
                dim=self.batch_dims[prev_status_name],
                index=self.beam_indice[-1]
            ).contiguous()

    def get_n_best(self, n=1, length_penalty=.2):
        '''
        Search the beam board built so far and extract the N sentences with
        the highest probability.
        '''
        sentences, probs, founds = [], [], []

        '''
        Using the masks, find the hypotheses that properly ended with EOS.
        For each one, record its EOS time-step, the beam index it ended in,
        and the cumulative probability at that point.
        '''
        for t in range(len(self.word_indice)):  # for each time-step,
            for b in range(self.beam_size):  # for each beam,
                if self.masks[t][b] == 1:  # if we had EOS on this time-step and beam,
                    # Take a record of the penalized log-probability.
                    probs += [self.cumulative_probs[t][b] * self.get_length_penalty(t, alpha=length_penalty)]
                    founds += [(t, b)]

        # Also collect hypotheses that never produced EOS but were cut off at
        # max_length: take their log-probability from the last time-step.
        for b in range(self.beam_size):
            if self.cumulative_probs[-1][b] != -float('inf'):  # If this beam does not have EOS,
                if not (len(self.cumulative_probs) - 1, b) in founds:
                    probs += [self.cumulative_probs[-1][b] * self.get_length_penalty(len(self.cumulative_probs),
                                                                                     alpha=length_penalty)]
                    founds += [(len(self.cumulative_probs) - 1, b)]

        # Sort and take n-best.
        # Pair the collected end indices with their probabilities,
        # sort in descending order, and keep the top N.
        sorted_founds_with_probs = sorted(
            zip(founds, probs),
            key=itemgetter(1),
            reverse=True,
        )[:n]
        probs = []

        '''
        Starting from each sorted end index (EOS), walk backwards through the
        sentence collecting words. Because a hypothesis may have hopped
        between beams while it was being built, we have to follow the beam
        index each word came from, one step at a time.
        '''
        for (end_index, b), prob in sorted_founds_with_probs:
            sentence = []

            # Trace from the end.
            for t in range(end_index, 0, -1):
                sentence = [self.word_indice[t][b]] + sentence
                b = self.beam_indice[t][b]

            sentences += [sentence]
            probs += [prob]

        return sentences, probs
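
A minimal sketch of the index arithmetic that collect_result relies on, using made-up sizes (beam_size = 2, output_size = 5) rather than anything from this commit: once the (beam_size, 1, output_size) scores are flattened, a single top-k index encodes both the source beam and the word, and fmod / integer division recover them.

import torch

beam_size, output_size = 2, 5                     # assumed toy sizes
scores = torch.randn(beam_size, 1, output_size)   # stand-in for cumulative_prob
top_log_prob, top_indice = scores.view(-1).sort(descending=True)
top_log_prob, top_indice = top_log_prob[:beam_size], top_indice[:beam_size]
word_indice = top_indice.fmod(output_size)                # flat index %  output_size -> word id
beam_indice = top_indice.div(float(output_size)).long()   # flat index // output_size -> source beam
# For example, a flat index of 7 decodes to word 7 % 5 = 2 coming from beam 7 // 5 = 1.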

src/11_seq2seq/modules/seq2seq.py

Lines changed: 158 additions & 2 deletions
@@ -3,7 +3,7 @@
 from torch.nn.utils.rnn import pack_padded_sequence as pack
 from torch.nn.utils.rnn import pad_packed_sequence as unpack
 import modules.data_loader as data_loader
-
+from modules.search import SingleBeamSearchBoard
 
 
 class Encoder(nn.Module):

@@ -404,4 +404,160 @@ def search(self, src, is_greedy=True, max_length=255):
         y_hats = torch.cat(y_hats, dim=1)
         indice = torch.cat(indice, dim=1)
 
-        return y_hats, indice
+        return y_hats, indice
+
+    def batch_beam_search(
+        self,
+        src,
+        beam_size=5,
+        max_length=255,
+        n_best=1,
+        length_penalty=.2
+    ):
+        mask, x_length = None, None
+
+        if isinstance(src, tuple):
+            x, x_length = src
+            mask = self.generate_mask(x, x_length)
+            # |mask| = (batch_size, length)
+        else:
+            x = src
+        batch_size = x.size(0)
+
+        emb_src = self.emb_src(x)
+        h_src, h_0_tgt = self.encoder((emb_src, x_length))
+        # |h_src| = (batch_size, length, hidden_size)
+        h_0_tgt = self.merge_encoder_hiddens(h_0_tgt)
+
+        '''
+        Initialize 'SingleBeamSearchBoard'.
+        For each sample in the batch, this class fabricates a fake batch of
+        size beam_size.
+            hidden_state: hidden state handed over from the encoder
+            cell_state: cell state handed over from the encoder
+            h_t_1_tilde: prediction of the previous step (input feeding);
+                         None at first, since there is no previous step yet.
+        '''
+        boards = [SingleBeamSearchBoard(
+            h_src.device,
+            {
+                'hidden_state': {
+                    'init_status': h_0_tgt[0][:, i, :].unsqueeze(1),
+                    'batch_dim_index': 1,
+                },  # |hidden_state| = (n_layers, batch_size, hidden_size)
+                'cell_state': {
+                    'init_status': h_0_tgt[1][:, i, :].unsqueeze(1),
+                    'batch_dim_index': 1,
+                },  # |cell_state| = (n_layers, batch_size, hidden_size)
+                'h_t_1_tilde': {
+                    'init_status': None,
+                    'batch_dim_index': 0,
+                },  # |h_t_1_tilde| = (batch_size, 1, hidden_size)
+            },
+            beam_size=beam_size,
+            max_length=max_length,
+        ) for i in range(batch_size)]
+        # Whether each board (i.e., each sample in the batch) has finished.
+        # All zeros at the beginning, of course.
+        is_done = [board.is_done() for board in boards]
+
+        length = 0
+        # Iterate until every sample is done or max_length is exceeded.
+        while sum(is_done) < batch_size and length <= max_length:
+            # current_batch_size = (batch_size - sum(is_done)) * beam_size
+
+            # Initialize fabricated variables.
+            # As far as batch-beam-search is running,
+            # temporary batch-size for fabricated mini-batch is
+            # 'beam_size'-times bigger than original batch_size.
+            fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
+            fab_h_src, fab_mask = [], []
+
+            # Expand each input by beam_size to build the fake batch.
+            # input, hidden, cell, and h_t_tilde have already been expanded
+            # inside the boards; only h_src and mask need to be expanded here.
+            for i, board in enumerate(boards):
+                # Batchify if the inference for the sample is still not finished.
+                if board.is_done() == 0:
+                    # Fetch the fake batch data needed at the current time-step.
+                    y_hat_i, prev_status = board.get_batch()
+                    hidden_i = prev_status['hidden_state']
+                    cell_i = prev_status['cell_state']
+                    h_t_tilde_i = prev_status['h_t_1_tilde']
+
+                    fab_input += [y_hat_i]
+                    fab_hidden += [hidden_i]
+                    fab_cell += [cell_i]
+                    fab_h_src += [h_src[i, :, :]] * beam_size
+                    fab_mask += [mask[i, :]] * beam_size
+                    if h_t_tilde_i is not None:
+                        fab_h_t_tilde += [h_t_tilde_i]
+                    else:
+                        fab_h_t_tilde = None
+
+            fab_input = torch.cat(fab_input, dim=0)
+            fab_hidden = torch.cat(fab_hidden, dim=1)
+            fab_cell = torch.cat(fab_cell, dim=1)
+            fab_h_src = torch.stack(fab_h_src)
+            fab_mask = torch.stack(fab_mask)
+            if fab_h_t_tilde is not None:
+                fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
+            # |fab_input| = (current_batch_size, 1)
+            # |fab_hidden| = (n_layers, current_batch_size, hidden_size)
+            # |fab_cell| = (n_layers, current_batch_size, hidden_size)
+            # |fab_h_src| = (current_batch_size, length, hidden_size)
+            # |fab_mask| = (current_batch_size, length)
+            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
+
+            emb_t = self.emb_dec(fab_input)
+            # |emb_t| = (current_batch_size, 1, word_vec_size)
+
+            fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(emb_t,
+                                                                      fab_h_t_tilde,
+                                                                      (fab_hidden, fab_cell))
+            # |fab_decoder_output| = (current_batch_size, 1, hidden_size)
+            context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
+            # |context_vector| = (current_batch_size, 1, hidden_size)
+            fab_h_t_tilde = self.tanh(self.concat(torch.cat([fab_decoder_output,
+                                                             context_vector
+                                                             ], dim=-1)))
+            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
+            y_hat = self.generator(fab_h_t_tilde)
+            # |y_hat| = (current_batch_size, 1, output_size)
+
+            # The decoder ran everything in parallel as if it were one batch;
+            # now split the results back into chunks of beam_size and hand
+            # them to each board.
+            # fab_hidden[:, begin:end, :] = (n_layers, beam_size, hidden_size)
+            # fab_cell[:, begin:end, :] = (n_layers, beam_size, hidden_size)
+            # fab_h_t_tilde[begin:end] = (beam_size, 1, hidden_size)
+            cnt = 0
+            for board in boards:
+                if board.is_done() == 0:
+                    # Decide a range of each sample.
+                    begin = cnt * beam_size
+                    end = begin + beam_size
+
+                    # Pick k-best results for each sample.
+                    board.collect_result(
+                        y_hat[begin:end],
+                        {
+                            'hidden_state': fab_hidden[:, begin:end, :],
+                            'cell_state' : fab_cell[:, begin:end, :],
+                            'h_t_1_tilde' : fab_h_t_tilde[begin:end],
+                        },
+                    )
+                    cnt += 1
+
+            is_done = [board.is_done() for board in boards]
+            length += 1
+
+        # Pick n-best hypotheses.
+        batch_sentences, batch_probs = [], []
+
+        # Collect the results.
+        for i, board in enumerate(boards):
+            sentences, probs = board.get_n_best(n_best, length_penalty=length_penalty)
+
+            batch_sentences += [sentences]
+            batch_probs += [probs]
+
+        return batch_sentences, batch_probs
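
A hedged usage sketch of the new method, under assumptions that are not part of this commit: `model` is a trained instance of this Seq2Seq class, `x` / `x_length` are a padded source mini-batch and its lengths, and `tgt_vocab.itos` is a hypothetical index-to-token lookup used only to print the result.

import torch

# Assumed to exist already: model (trained Seq2Seq), x, x_length, tgt_vocab.
model.eval()
with torch.no_grad():
    # batch_sentences[i] holds the n_best hypotheses for sample i; each
    # hypothesis is a list of word-index tensors traced back by
    # SingleBeamSearchBoard.get_n_best, with matching scores in batch_probs.
    batch_sentences, batch_probs = model.batch_beam_search(
        (x, x_length),
        beam_size=5,
        max_length=255,
        n_best=1,
        length_penalty=.2,
    )

for sentences, probs in zip(batch_sentences, batch_probs):
    best, best_prob = sentences[0], probs[0]
    tokens = [tgt_vocab.itos[int(w)] for w in best]
    print(' '.join(tokens), float(best_prob))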
