Commit 9064c93

Transformer implementation
1 parent 1d8f862 commit 9064c93

File tree

9 files changed: +1359, -158 lines

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
import sys
import os.path
import torch

from train import define_argparser
from train import main


def overwrite_config(config, prev_config):
    # This method provides compatibility for new or missing arguments.
    for prev_key in vars(prev_config).keys():
        if not prev_key in vars(config).keys():
            # No such argument in the current config. Ignore the saved value.
            print('WARNING!!! Argument "--%s" is not found in current argument parser.\tIgnore saved value:' % prev_key,
                  vars(prev_config)[prev_key])

    for key in vars(config).keys():
        if not key in vars(prev_config).keys():
            # No such argument in the saved file. Use the current value.
            print('WARNING!!! Argument "--%s" is not found in saved model.\tUse current value:' % key,
                  vars(config)[key])
        elif vars(config)[key] != vars(prev_config)[key]:
            if '--%s' % key in sys.argv:
                # The user changed this argument's value for this execution.
                print('WARNING!!! You changed value for argument "--%s".\tUse current value:' % key,
                      vars(config)[key])
            else:
                # The user didn't change it for this execution, but the current and saved configs differ.
                # This may have been caused by the user's intention at the last execution.
                # Load the old value and replace the current value with it.
                vars(config)[key] = vars(prev_config)[key]

    return config


def continue_main(config, main):
    # If the model file exists, load the model and configuration to continue training.
    if os.path.isfile(config.load_fn):
        saved_data = torch.load(config.load_fn, map_location='cpu')

        prev_config = saved_data['config']
        config = overwrite_config(config, prev_config)

        model_weight = saved_data['model']
        opt_weight = saved_data['opt']

        main(config, model_weight=model_weight, opt_weight=opt_weight)
    else:
        print('Cannot find file %s' % config.load_fn)


if __name__ == '__main__':
    config = define_argparser(is_continue=True)
    continue_main(config, main)
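For context, a minimal sketch of how a checkpoint compatible with continue_main() above could be written from the training side. The save_checkpoint helper and the file name are hypothetical; the only assumptions taken from this file are that the saved dict carries the keys 'config', 'model', and 'opt', and that define_argparser(is_continue=True) in train.py registers a --load_fn argument.

    import torch

    def save_checkpoint(config, model, optimizer, path='model.pth'):
        # continue_main() expects exactly these keys when it reloads the file.
        torch.save({
            'config': config,                 # argparse.Namespace from define_argparser()
            'model': model.state_dict(),      # model weights
            'opt': optimizer.state_dict(),    # optimizer state
        }, path)

Resuming is then a matter of pointing --load_fn at that file when running this script.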

src/12_transformer/detokenizer.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
#-*- coding:utf-8 -*-
import sys
sys.stdin.reconfigure(encoding='utf-8')


if __name__ == "__main__":
    for line in sys.stdin:
        if line.strip() != "":
            if '▁▁' in line:
                line = line.strip().replace(' ', '').replace('▁▁', ' ').replace('▁', '').strip()
            else:
                line = line.strip().replace(' ', '').replace('▁', ' ').strip()

            sys.stdout.write(line + '\n')
        else:
            sys.stdout.write('\n')
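The filter above undoes subword tokenization that marks word boundaries with '▁' ('▁▁' appears when a word-level tokenizer ran before subword segmentation). A small illustration of the two branches, with made-up sample strings:

    line = '▁▁나는 ▁▁학교 에 ▁▁간다'
    # Token-separating spaces are removed, '▁▁' becomes a real space, leftover '▁' is dropped.
    print(line.strip().replace(' ', '').replace('▁▁', ' ').replace('▁', '').strip())
    # -> 나는 학교에 간다

    line = '▁I ▁go ▁to ▁school'
    # Without '▁▁', every '▁' is treated as a word boundary.
    print(line.strip().replace(' ', '').replace('▁', ' ').strip())
    # -> I go to school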
Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
import os
from torchtext import data

PAD, BOS, EOS = 1, 2, 3


class DataLoader:

    def __init__(
        self,
        train_fn=None,
        valid_fn=None,
        exts=None,
        batch_size=64,
        device=-1,              # -1 for CPU, otherwise the CUDA device index
        max_vocab=9999999,
        max_length=255,
        fix_length=None,
        use_bos=True,
        use_eos=True,
        shuffle=True,
    ):

        self.src = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            init_token=None,
            eos_token=None,
        )
        self.tgt = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            init_token='<BOS>' if use_bos else None,
            eos_token='<EOS>' if use_eos else None,
        )

        if train_fn is not None and valid_fn is not None and exts is not None:
            train = TranslationDataset(
                path=train_fn,
                exts=exts,
                fields=[('src', self.src), ('tgt', self.tgt)],
                max_length=max_length,
            )
            valid = TranslationDataset(
                path=valid_fn,
                exts=exts,
                fields=[('src', self.src), ('tgt', self.tgt)],
                max_length=max_length,
            )

            self.train_iter = data.BucketIterator(
                train,
                batch_size=batch_size,
                device='cuda:%d' % device if device >= 0 else 'cpu',
                shuffle=shuffle,
                # Sort so that sequences of similar length end up in the same mini-batch.
                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
                sort_within_batch=True,
            )
            self.valid_iter = data.BucketIterator(
                valid,
                batch_size=batch_size,
                device='cuda:%d' % device if device >= 0 else 'cpu',
                shuffle=False,
                # Sort so that sequences of similar length end up in the same mini-batch.
                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
                sort_within_batch=True,
            )

            self.src.build_vocab(train, max_size=max_vocab)
            self.tgt.build_vocab(train, max_size=max_vocab)

    def load_vocab(self, src_vocab, tgt_vocab):
        self.src.vocab = src_vocab
        self.tgt.vocab = tgt_vocab


class TranslationDataset(data.Dataset):

    def __init__(self, path, exts, fields, max_length=None, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Overridden so that examples longer than max_length are dropped.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1])]

        if not path.endswith('.'):
            path += '.'

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

        examples = []
        with open(src_path, encoding='utf-8') as src_file, open(trg_path, encoding='utf-8') as trg_file:
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                if max_length and max_length < max(len(src_line.split()), len(trg_line.split())):
                    continue
                if src_line != '' and trg_line != '':
                    examples += [data.Example.fromlist([src_line, trg_line], fields)]

        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.src), len(ex.trg))
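A minimal usage sketch for the DataLoader above, assuming the legacy torchtext Field/BucketIterator API it relies on is installed and that a parallel corpus exists as corpus.train.en / corpus.train.ko and corpus.valid.en / corpus.valid.ko; the file names and hyperparameters here are illustrative only, not taken from the commit:

    loader = DataLoader(
        train_fn='corpus.train',      # common path prefix; extensions from 'exts' are appended
        valid_fn='corpus.valid',
        exts=('en', 'ko'),            # (source extension, target extension)
        batch_size=128,
        device=-1,                    # -1 keeps batches on the CPU
        max_length=100,               # pairs whose longer side exceeds 100 tokens are dropped
    )

    print(len(loader.src.vocab), len(loader.tgt.vocab))

    for batch in loader.train_iter:
        src, src_length = batch.src   # include_lengths=True yields (token tensor, lengths)
        tgt, tgt_length = batch.tgt
        # |src| = (batch_size, src_len), |tgt| = (batch_size, tgt_len)
        break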
Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
from operator import itemgetter

import torch
import torch.nn as nn

import modules.data_loader as data_loader

LENGTH_PENALTY = .2
MIN_LENGTH = 5


class SingleBeamSearchBoard():

    def __init__(
        self,
        device,
        prev_status_config,
        beam_size=5,
        max_length=255,
    ):
        self.beam_size = beam_size
        self.max_length = max_length

        # To put data on the same device.
        self.device = device
        # Inferred word index for each time-step. For now, initialized with the initial time-step.
        self.word_indice = [torch.LongTensor(beam_size).zero_().to(self.device) + data_loader.BOS]
        # Beam index for the selected word index, at each time-step.
        self.beam_indice = [torch.LongTensor(beam_size).zero_().to(self.device) - 1]
        # Cumulative log-probability for each beam.
        self.cumulative_probs = [torch.FloatTensor([.0] + [-float('inf')] * (beam_size - 1)).to(self.device)]
        # 1 if the beam is finished, else 0.
        self.masks = [torch.BoolTensor(beam_size).zero_().to(self.device)]

        # We don't need to remember every time-step of hidden states:
        #   prev_hidden, prev_cell, prev_h_t_tilde
        # We only need to remember the last one.
        self.prev_status = {}
        self.batch_dims = {}
        for prev_status_name, each_config in prev_status_config.items():
            init_status = each_config['init_status']
            batch_dim_index = each_config['batch_dim_index']
            if init_status is not None:
                self.prev_status[prev_status_name] = torch.cat([init_status] * beam_size,
                                                               dim=batch_dim_index)
            else:
                self.prev_status[prev_status_name] = None
            self.batch_dims[prev_status_name] = batch_dim_index

        self.current_time_step = 0
        self.done_cnt = 0

    def get_length_penalty(
        self,
        length,
        alpha=LENGTH_PENALTY,
        min_length=MIN_LENGTH,
    ):
        # Calculate the length penalty,
        # because shorter sentences usually have higher probability.
        # In fact, we represent this as a log-probability, which is a negative value.
        # Thus, we need to apply a bigger penalty to shorter sentences.
        p = ((min_length + 1) / (min_length + length))**alpha

        return p

    def is_done(self):
        # Return 1 if we have seen EOS at least 'beam_size' times.
        if self.done_cnt >= self.beam_size:
            return 1
        return 0

    def get_batch(self):
        y_hat = self.word_indice[-1].unsqueeze(-1)
        # |y_hat| = (beam_size, 1)
        # if model != transformer:
        #     |hidden| = |cell| = (n_layers, beam_size, hidden_size)
        #     |h_t_tilde| = (beam_size, 1, hidden_size) or None
        # else:
        #     |prev_state_i| = (beam_size, length, hidden_size),
        #     where i is an index of a layer.
        return y_hat, self.prev_status

    #@profile
    def collect_result(self, y_hat, prev_status):
        # |y_hat| = (beam_size, 1, output_size)
        # prev_status is a dict with the following keys:
        # if model != transformer:
        #     |hidden| = |cell| = (n_layers, beam_size, hidden_size)
        #     |h_t_tilde| = (beam_size, 1, hidden_size)
        # else:
        #     |prev_state_i| = (beam_size, length, hidden_size),
        #     where i is an index of a layer.
        output_size = y_hat.size(-1)

        self.current_time_step += 1

        # Calculate the cumulative log-probability.
        # First, fill -inf into the last cumulative probability if the beam is already finished.
        # Second, expand the -inf-filled cumulative probability to fit 'y_hat':
        #   (beam_size) --> (beam_size, 1, 1) --> (beam_size, 1, output_size)
        # Third, add the expanded cumulative probability to 'y_hat'.
        cumulative_prob = self.cumulative_probs[-1].masked_fill_(self.masks[-1], -float('inf'))
        cumulative_prob = y_hat + cumulative_prob.view(-1, 1, 1).expand(self.beam_size, 1, output_size)
        # |cumulative_prob| = (beam_size, 1, output_size)

        # Now we have the new top log-probabilities and their indices.
        # We pick as many top indices as 'beam_size'.
        # Be aware that we pick the top-k from the whole batch through 'view(-1)'.

        # The following lines use torch.topk, which is slower than torch.sort.
        # top_log_prob, top_indice = torch.topk(
        #     cumulative_prob.view(-1), # (beam_size * output_size,)
        #     self.beam_size,
        #     dim=-1,
        # )

        # The following lines use torch.sort instead of torch.topk.
        top_log_prob, top_indice = cumulative_prob.view(-1).sort(descending=True)
        top_log_prob, top_indice = top_log_prob[:self.beam_size], top_indice[:self.beam_size]
        # |top_log_prob| = (beam_size,)
        # |top_indice| = (beam_size,)

        # Because we picked from the whole batch, the original word index must be recovered.
        self.word_indice += [top_indice.fmod(output_size)]
        # Also, we can get the index of the beam that produced each top-k log-probability.
        self.beam_indice += [top_indice.div(float(output_size)).long()]

        # Add results to the history boards.
        self.cumulative_probs += [top_log_prob]
        self.masks += [torch.eq(self.word_indice[-1], data_loader.EOS)]  # Set finish mask if we got EOS.
        # Count the number of finished beams.
        self.done_cnt += self.masks[-1].float().sum()

        # In the beam search procedure, we only need to memorize the latest status.
        # For seq2seq, that would be the latest hidden and cell state, and h_t_tilde.
        # The problem is that the hidden (or cell) state and h_t_tilde have different dimension orders.
        # In other words, the dimension holding the batch index differs.
        # Therefore self.batch_dims stores the dimension index of the batch dimension.
        # For the transformer, the latest status is each layer's decoder output from the beginning.
        # Unlike seq2seq, the transformer has to memorize every previous output for the attention operation.
        for prev_status_name, prev_status in prev_status.items():
            self.prev_status[prev_status_name] = torch.index_select(
                prev_status,
                dim=self.batch_dims[prev_status_name],
                index=self.beam_indice[-1]
            ).contiguous()

    def get_n_best(self, n=1, length_penalty=.2):
        sentences, probs, founds = [], [], []

        for t in range(len(self.word_indice)):  # for each time-step,
            for b in range(self.beam_size):  # for each beam,
                if self.masks[t][b] == 1:  # if we had EOS on this time-step and beam,
                    # Take a record of the penalized log-probability.
                    probs += [self.cumulative_probs[t][b] * self.get_length_penalty(t, alpha=length_penalty)]
                    founds += [(t, b)]

        # Also, collect log-probabilities from the last time-step, for the case where EOS never appeared.
        for b in range(self.beam_size):
            if self.cumulative_probs[-1][b] != -float('inf'):  # If this beam does not have EOS,
                if not (len(self.cumulative_probs) - 1, b) in founds:
                    probs += [self.cumulative_probs[-1][b] * self.get_length_penalty(len(self.cumulative_probs),
                                                                                     alpha=length_penalty)]
                    founds += [(len(self.cumulative_probs) - 1, b)]

        # Sort and take the n-best.
        sorted_founds_with_probs = sorted(
            zip(founds, probs),
            key=itemgetter(1),
            reverse=True,
        )[:n]
        probs = []

        for (end_index, b), prob in sorted_founds_with_probs:
            sentence = []

            # Trace back from the end.
            for t in range(end_index, 0, -1):
                sentence = [self.word_indice[t][b]] + sentence
                b = self.beam_indice[t][b]

            sentences += [sentence]
            probs += [prob]

        return sentences, probs
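A self-contained sketch of how a decoding loop might drive SingleBeamSearchBoard above, using a dummy decode_step that returns uniform log-probabilities in place of a real decoder. decode_step, prev_status_config, and all sizes below are made up for illustration, and the snippet assumes this module's own import of modules.data_loader resolves.

    import torch

    beam_size, output_size, hidden_size = 5, 10, 16

    # Hypothetical stand-in for one decoder step: uniform log-probabilities over a
    # tiny vocabulary, and the status dict passed through unchanged.
    def decode_step(y_hat_t, prev_status):
        log_prob = torch.ones(beam_size, 1, output_size).log_softmax(dim=-1)
        return log_prob, prev_status

    board = SingleBeamSearchBoard(
        'cpu',
        # One fake per-layer status tensor with its batch dimension at index 0.
        {'prev_state_0': {'init_status': torch.zeros(1, 1, hidden_size), 'batch_dim_index': 0}},
        beam_size=beam_size,
        max_length=20,
    )

    while not board.is_done() and board.current_time_step < board.max_length:
        y_hat_t, prev_status = board.get_batch()        # |y_hat_t| = (beam_size, 1)
        log_prob_t, new_status = decode_step(y_hat_t, prev_status)
        board.collect_result(log_prob_t, new_status)    # keeps only the top 'beam_size' paths

    sentences, probs = board.get_n_best(n=1)

With the defaults alpha=0.2 and min_length=5, get_length_penalty() returns 1.0 at length 1, roughly 0.90 at length 5, and roughly 0.64 at length 50, so the negative log-probability of a longer finished hypothesis is shrunk more and that hypothesis is penalized less.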
