@@ -472,39 +472,72 @@ def search(self, x, is_greedy=True, max_length=255):
472472 mask_dec = mask .unsqueeze (1 )
473473 # |mask_enc| = (batch_size, n, n)
474474 # |mask_dec| = (batch_size, 1, n)
475+ # 학습때는 모든 타임스텝을 한번에 보내서 (bs, m, n)이지만
476+ # 추론시에는 다음 단어를 미리 모르므로, (bs, 1, n)
477+ # 그냥 이 디코더 마스크를 계속 재활용해서 사용함
475478
476479 z = self .emb_dropout (self ._position_encoding (self .emb_enc (x )))
477480 z , _ = self .encoder (z , mask_enc )
478481 # |z| = (batch_size, n, hidden_size)
479482
480- # Fill a vector, which has 'batch_size' dimension, with BOS value.
483+ # 각 배치마다 현재 타입스텝 예측 단어를 저장하는 공간
484+ # 처음에는 (batch_size, 1)로 전부 BOS로 채워져있음
481485 y_t_1 = x .new (batch_size , 1 ).zero_ () + data_loader .BOS
482486 # |y_t_1| = (batch_size, 1)
487+
488+ # 현재 각 배치가 디코딩이 진행중인 것인지 True or False
489+ # 합쳐서 1이상이면 최소 한개라도 진행중이므로 속행하게 됨
483490 is_decoding = x .new_ones (batch_size , 1 ).bool ()
484491
492+ '''
493+ # seq2seq는 이전 맨 마지막 hidden state만 갖고 있었으면 되었는데
494+ # transformer는 어텐션을 때리기 위해,
495+ # 각 계층마다 모든 타임스텝의 hidden_state를 참고해야 함
496+ # 따라서 처음에 decorder의 계층 수 + 1 (입력단 포함)의 None으로 가득찬 리스트 생성
497+ # prevs: 각 계층마다 hidden 값을 저장하기 위한 공간
498+ prev:
499+ [
500+ 1계층: (batch_size, 1 -> N, hidden_size),
501+ 2계층: (batch_size, 1 -> N, hidden_size),
502+ 3계층: (batch_size, 1 -> N, hidden_size),
503+ ...
504+ ]
505+ 첫 번째 타임스텝의 (1, 2, 3) 계층을 돌면서 prev를 박고...
506+ 두 번째 타임스텝의 (1, 2, 3) 계층을 돌면서 prev를 박고...
507+ '''
485508 prevs = [None for _ in range (len (self .decoder ._modules ) + 1 )]
486509 y_hats , indice = [], []
487- # Repeat a loop while sum of 'is_decoding' flag is bigger than 0,
488- # or current time-step is smaller than maximum length.
510+
489511 while is_decoding .sum () > 0 and len (indice ) < max_length :
490- # Unlike training procedure,
491- # take the last time-step's output during the inference.
512+ # 각 배치의 한 타임스텝에 대한 히든 사이즈
492513 h_t = self .emb_dropout (
493514 self ._position_encoding (self .emb_dec (y_t_1 ), init_pos = len (indice ))
494515 )
495516 # |h_t| = (batch_size, 1, hidden_size))
517+
518+ # 맨 처음은 예외처리로써, 첫 계층 prev([0]) 정보들은
519+ # 디코더에 들어가기 전에 처리후, 저장해줌
520+ # 왜냐하면 디코더를 들어가야 하기 때문
496521 if prevs [0 ] is None :
522+ # None일 경우(초기상태), 그대로 바꿔주고
497523 prevs [0 ] = h_t
498524 else :
525+ # 유효한 텐서값일 경우, 해당 텐서 밑에 그대로 붙여줌
499526 prevs [0 ] = torch .cat ([prevs [0 ], h_t ], dim = 1 )
500527
528+ # Decoder Block를 하나하나 뽑아서 반복문 수행
501529 for layer_index , block in enumerate (self .decoder ._modules .values ()):
530+ # 현재 계층의 prev_status만 가져옴
502531 prev = prevs [layer_index ]
503532 # |prev| = (batch_size, len(y_hats), hidden_size)
504533
534+ # 맨 처음 들어갈때는, 이전 타임스텝이 없음
535+ # 그래서 prevs[0]도 h_t가 들어 있음(h_t == prev)
536+ # 즉 규격에 맞춰주기 위해서 prev를 넣어준 셈(결국 완전 셀프 어텐션)
505537 h_t , _ , _ , _ , _ = block (h_t , z , mask_dec , prev , None )
506538 # |h_t| = (batch_size, 1, hidden_size)
507539
540+ # 이번에 나온 결과(h_t)를 다음 계층의 Prev로써 쓸 수 있도록 가져옴
508541 if prevs [layer_index + 1 ] is None :
509542 prevs [layer_index + 1 ] = h_t
510543 else :
@@ -520,6 +553,8 @@ def search(self, x, is_greedy=True, max_length=255):
520553 else : # Random sampling
521554 y_t_1 = torch .multinomial (y_hat_t .exp ().view (x .size (0 ), - 1 ), 1 )
522555 # Put PAD if the sample is done.
556+ # is_decoding이 False인 곳, 즉 EOS로 문장이 끝난 곳은
557+ # y_t_1에 <PAD>를 덮어씌워줌
523558 y_t_1 = y_t_1 .masked_fill_ (
524559 ~ is_decoding ,
525560 data_loader .PAD ,
@@ -534,8 +569,9 @@ def search(self, x, is_greedy=True, max_length=255):
534569 y_hats = torch .cat (y_hats , dim = 1 )
535570 indice = torch .cat (indice , dim = - 1 )
536571 # |y_hats| = (batch_size, m, output_size)
572+ # 각 단어의 확률 값
537573 # |indice| = (batch_size, m)
538-
574+ # 확률값을 토대로 구한 최종 단어 인덱스 모음
539575 return y_hats , indice
540576
541577 #@profile
@@ -628,7 +664,7 @@ def batch_beam_search(
628664 # |fab_input| = (current_batch_size, 1,)
629665 # |fab_z| = (current_batch_size, n, hidden_size)
630666 # |fab_mask| = (current_batch_size, 1, n)
631- # |fab_prevs[i]| = (current_batch_size , length, hidden_size)
667+ # |fab_prevs[i]| = (cur rent_batch_size , length, hidden_size)
632668 # len(fab_prevs) = n_dec_layers + 1
633669
634670 # Unlike training procedure,
0 commit comments