Skip to content

Commit f70c52a

Browse files
committed
transformer search 구현
1 parent de856f2 commit f70c52a

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

src/11_seq2seq/modules/search.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ def __init__(
2626
# 처음에는 모두 <BOS>므로 초기화
2727
self.word_indice = [torch.LongTensor(beam_size).zero_().to(self.device) + data_loader.BOS]
2828
# 각 타임 스텝의 Word들이 선정된 Beam Index
29+
# 현재 word_indice들의 단어가 각각 어느 Beam에서 선정되었는지 체크
2930
# 처음에는 아무것도 선정되지 않았기에 -1로 초기화
3031
self.beam_indice = [torch.LongTensor(beam_size).zero_().to(self.device) - 1]
3132
# 각 Beam들의 누적 확률 값
32-
# 처음에는 [0, -inf, -inf, ...]로 초기화
33+
# 처음에는 모두 BOS 이므로 모든 빔의 결과가 같을 것임.
34+
# 따라서 첫 번째 빔만 선정하게 하기 위해 [0, -inf, -inf, ...]로 초기화
3335
self.cumulative_probs = [torch.FloatTensor([.0] + [-float('inf')] * (beam_size - 1)).to(self.device)]
3436
# 각 빔이 현재 EOS에 도달했는지 여부
3537
# 1 if it is done else 0
@@ -123,7 +125,7 @@ def collect_result(self, y_hat, prev_status):
123125
cumulative_prob = y_hat + cumulative_prob.view(-1, 1, 1).expand(self.beam_size, 1, output_size)
124126
# |cumulative_prob| = (beam_size, 1, output_size)
125127

126-
# cumulative_prob를 (beam_size * output_size,)로
128+
# cumulative_prob를 (output_size * beam_size,)로
127129
# flatten 해준후 확률이 높은 순으로 정렬
128130
# top_indice에는 원래 정렬되기전 index가 유지됨
129131
top_log_prob, top_indice = cumulative_prob.view(-1).sort(descending=True)
@@ -164,7 +166,7 @@ def collect_result(self, y_hat, prev_status):
164166
def get_n_best(self, n=1, length_penalty=.2):
165167
'''
166168
이때까지의 Beam Board를 찾아보며,
167-
가장 확률 값이 높았던 N개의 문장 추출
169+
가장 누적 확률 값이 높았던 N개의 문장 추출
168170
'''
169171
sentences, probs, founds = [], [], []
170172

src/12_transformer/modules/transformer.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -472,39 +472,72 @@ def search(self, x, is_greedy=True, max_length=255):
472472
mask_dec = mask.unsqueeze(1)
473473
# |mask_enc| = (batch_size, n, n)
474474
# |mask_dec| = (batch_size, 1, n)
475+
# 학습때는 모든 타임스텝을 한번에 보내서 (bs, m, n)이지만
476+
# 추론시에는 다음 단어를 미리 모르므로, (bs, 1, n)
477+
# 그냥 이 디코더 마스크를 계속 재활용해서 사용함
475478

476479
z = self.emb_dropout(self._position_encoding(self.emb_enc(x)))
477480
z, _ = self.encoder(z, mask_enc)
478481
# |z| = (batch_size, n, hidden_size)
479482

480-
# Fill a vector, which has 'batch_size' dimension, with BOS value.
483+
# 각 배치마다 현재 타입스텝 예측 단어를 저장하는 공간
484+
# 처음에는 (batch_size, 1)로 전부 BOS로 채워져있음
481485
y_t_1 = x.new(batch_size, 1).zero_() + data_loader.BOS
482486
# |y_t_1| = (batch_size, 1)
487+
488+
# 현재 각 배치가 디코딩이 진행중인 것인지 True or False
489+
# 합쳐서 1이상이면 최소 한개라도 진행중이므로 속행하게 됨
483490
is_decoding = x.new_ones(batch_size, 1).bool()
484491

492+
'''
493+
# seq2seq는 이전 맨 마지막 hidden state만 갖고 있었으면 되었는데
494+
# transformer는 어텐션을 때리기 위해,
495+
# 각 계층마다 모든 타임스텝의 hidden_state를 참고해야 함
496+
# 따라서 처음에 decorder의 계층 수 + 1 (입력단 포함)의 None으로 가득찬 리스트 생성
497+
# prevs: 각 계층마다 hidden 값을 저장하기 위한 공간
498+
prev:
499+
[
500+
1계층: (batch_size, 1 -> N, hidden_size),
501+
2계층: (batch_size, 1 -> N, hidden_size),
502+
3계층: (batch_size, 1 -> N, hidden_size),
503+
...
504+
]
505+
첫 번째 타임스텝의 (1, 2, 3) 계층을 돌면서 prev를 박고...
506+
두 번째 타임스텝의 (1, 2, 3) 계층을 돌면서 prev를 박고...
507+
'''
485508
prevs = [None for _ in range(len(self.decoder._modules) + 1)]
486509
y_hats, indice = [], []
487-
# Repeat a loop while sum of 'is_decoding' flag is bigger than 0,
488-
# or current time-step is smaller than maximum length.
510+
489511
while is_decoding.sum() > 0 and len(indice) < max_length:
490-
# Unlike training procedure,
491-
# take the last time-step's output during the inference.
512+
# 각 배치의 한 타임스텝에 대한 히든 사이즈
492513
h_t = self.emb_dropout(
493514
self._position_encoding(self.emb_dec(y_t_1), init_pos=len(indice))
494515
)
495516
# |h_t| = (batch_size, 1, hidden_size))
517+
518+
# 맨 처음은 예외처리로써, 첫 계층 prev([0]) 정보들은
519+
# 디코더에 들어가기 전에 처리후, 저장해줌
520+
# 왜냐하면 디코더를 들어가야 하기 때문
496521
if prevs[0] is None:
522+
# None일 경우(초기상태), 그대로 바꿔주고
497523
prevs[0] = h_t
498524
else:
525+
# 유효한 텐서값일 경우, 해당 텐서 밑에 그대로 붙여줌
499526
prevs[0] = torch.cat([prevs[0], h_t], dim=1)
500527

528+
# Decoder Block를 하나하나 뽑아서 반복문 수행
501529
for layer_index, block in enumerate(self.decoder._modules.values()):
530+
# 현재 계층의 prev_status만 가져옴
502531
prev = prevs[layer_index]
503532
# |prev| = (batch_size, len(y_hats), hidden_size)
504533

534+
# 맨 처음 들어갈때는, 이전 타임스텝이 없음
535+
# 그래서 prevs[0]도 h_t가 들어 있음(h_t == prev)
536+
# 즉 규격에 맞춰주기 위해서 prev를 넣어준 셈(결국 완전 셀프 어텐션)
505537
h_t, _, _, _, _ = block(h_t, z, mask_dec, prev, None)
506538
# |h_t| = (batch_size, 1, hidden_size)
507539

540+
# 이번에 나온 결과(h_t)를 다음 계층의 Prev로써 쓸 수 있도록 가져옴
508541
if prevs[layer_index + 1] is None:
509542
prevs[layer_index + 1] = h_t
510543
else:
@@ -520,6 +553,8 @@ def search(self, x, is_greedy=True, max_length=255):
520553
else: # Random sampling
521554
y_t_1 = torch.multinomial(y_hat_t.exp().view(x.size(0), -1), 1)
522555
# Put PAD if the sample is done.
556+
# is_decoding이 False인 곳, 즉 EOS로 문장이 끝난 곳은
557+
# y_t_1에 <PAD>를 덮어씌워줌
523558
y_t_1 = y_t_1.masked_fill_(
524559
~is_decoding,
525560
data_loader.PAD,
@@ -534,8 +569,9 @@ def search(self, x, is_greedy=True, max_length=255):
534569
y_hats = torch.cat(y_hats, dim=1)
535570
indice = torch.cat(indice, dim=-1)
536571
# |y_hats| = (batch_size, m, output_size)
572+
# 각 단어의 확률 값
537573
# |indice| = (batch_size, m)
538-
574+
# 확률값을 토대로 구한 최종 단어 인덱스 모음
539575
return y_hats, indice
540576

541577
#@profile
@@ -628,7 +664,7 @@ def batch_beam_search(
628664
# |fab_input| = (current_batch_size, 1,)
629665
# |fab_z| = (current_batch_size, n, hidden_size)
630666
# |fab_mask| = (current_batch_size, 1, n)
631-
# |fab_prevs[i]| = (current_batch_size, length, hidden_size)
667+
# |fab_prevs[i]| = (cur rent_batch_size, length, hidden_size)
632668
# len(fab_prevs) = n_dec_layers + 1
633669

634670
# Unlike training procedure,

0 commit comments

Comments
 (0)