
Commit 190072b

committed: seq2seq error checking
1 parent: 9db2d2c

File tree: 6 files changed, +48 -69 lines

src/11_seq2seq/data_preparation/detokenizer.py

Lines changed: 3 additions & 0 deletions
@@ -1,4 +1,7 @@
+#-*- coding:utf-8 -*-
 import sys
+sys.stdin.reconfigure(encoding='utf-8')
+
 
 if __name__ == "__main__":
     for line in sys.stdin:
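The added lines pin standard input to UTF-8 so that Korean text piped into the detokenizer no longer depends on the system locale. A minimal sketch of the same idea, assuming the script may also have to run on Python < 3.7, where TextIOWrapper.reconfigure() is unavailable; the fallback wrapper is an assumption and not part of the commit:

import io
import sys

# Python 3.7+ exposes reconfigure() on sys.stdin; older interpreters need a re-wrap.
if hasattr(sys.stdin, 'reconfigure'):
    sys.stdin.reconfigure(encoding='utf-8')
else:
    sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')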

src/11_seq2seq/modules/data_loader.py

Lines changed: 38 additions & 37 deletions
@@ -7,9 +7,9 @@ class DataLoader:
 
     def __init__(
         self,
-        train_fn,
-        valid_fn,
-        exts,
+        train_fn=None,
+        valid_fn=None,
+        exts=None,
         batch_size=64,
         device='cpu',
         max_vocab=9999999,
@@ -39,40 +39,41 @@ def __init__(
             eos_token=None,
         )
 
-        train = TranslationDataset(
-            path=train_fn,
-            exts=exts,
-            fields=[('src', self.src), ('tgt', self.tgt)],
-            max_length=max_length
-        )
-        valid = TranslationDataset(
-            path=valid_fn,
-            exts=exts,
-            fields=[('src', self.src), ('tgt', self.tgt)],
-            max_length=max_length,
-        )
-
-        self.train_iter = data.BucketIterator(
-            train,
-            batch_size=batch_size,
-            device='cuda:%d' % device if device >= 0 else 'cpu',
-            shuffle=shuffle,
-            # sort so that mini-batches are built from sequences of similar length
-            sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
-            sort_within_batch=True,
-        )
-        self.valid_iter = data.BucketIterator(
-            valid,
-            batch_size=batch_size,
-            device='cuda:%d' % device if device >= 0 else 'cpu',
-            shuffle=False,
-            # sort so that mini-batches are built from sequences of similar length
-            sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
-            sort_within_batch=True,
-        )
-
-        self.src.build_vocab(train, max_size=max_vocab)
-        self.tgt.build_vocab(train, max_size=max_vocab)
+        if train_fn is not None and valid_fn is not None and exts is not None:
+            train = TranslationDataset(
+                path=train_fn,
+                exts=exts,
+                fields=[('src', self.src), ('tgt', self.tgt)],
+                max_length=max_length
+            )
+            valid = TranslationDataset(
+                path=valid_fn,
+                exts=exts,
+                fields=[('src', self.src), ('tgt', self.tgt)],
+                max_length=max_length,
+            )
+
+            self.train_iter = data.BucketIterator(
+                train,
+                batch_size=batch_size,
+                device='cuda:%d' % device if device >= 0 else 'cpu',
+                shuffle=shuffle,
+                # sort so that mini-batches are built from sequences of similar length
+                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
+                sort_within_batch=True,
+            )
+            self.valid_iter = data.BucketIterator(
+                valid,
+                batch_size=batch_size,
+                device='cuda:%d' % device if device >= 0 else 'cpu',
+                shuffle=False,
+                # sort so that mini-batches are built from sequences of similar length
+                sort_key=lambda x: len(x.tgt) + (max_length * len(x.src)),
+                sort_within_batch=True,
+            )
+
+            self.src.build_vocab(train, max_size=max_vocab)
+            self.tgt.build_vocab(train, max_size=max_vocab)
 
     def load_vocab(self, src_vocab, tgt_vocab):
         self.src.vocab = src_vocab
src/11_seq2seq/modules/seq2seq.py

Lines changed: 4 additions & 1 deletion
@@ -17,8 +17,8 @@ def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
             hidden_size=int(hidden_size / 2),
             num_layers=n_layers,
             dropout=dropout_p,
+            bidirectional=True,
             batch_first=True,
-            bidirectional=True
         )
 
     def forward(self, emb):
@@ -264,6 +264,9 @@ def forward(self, src, tgt):
             mask = self.generate_mask(x, x_length)
         else:
             x = src
+
+        if isinstance(tgt, tuple):
+            tgt = tgt[0]
 
         #---------Encoder Step---------#
         # emb_src = (batch_size, length_n, word_vec_size)
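The new guard handles batches where tgt arrives as a (tokens, lengths) pair, which is what a torchtext Field configured with include_lengths=True produces; the decoder only needs the token tensor. An illustrative sketch with assumed shapes, not taken from the repo:

import torch

# A fake torchtext-style target batch: (token indices, sequence lengths).
batch_tgt = (torch.randint(0, 100, (64, 20)), torch.full((64,), 20, dtype=torch.long))

tgt = batch_tgt
if isinstance(tgt, tuple):
    tgt = tgt[0]              # keep only the (batch_size, length) token tensor

assert tgt.shape == (64, 20)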

src/11_seq2seq/modules/trainer.py

Lines changed: 0 additions & 3 deletions
@@ -113,10 +113,7 @@ def train(engine, mini_batch):
 
         # total number of tokens in the current batch
         word_count = int(mini_batch.tgt[1].sum())
-
-        # grows over time
         p_norm = float(get_parameter_norm(engine.model.parameters()))
-        # shrinks over time
         g_norm = float(get_grad_norm(engine.model.parameters()))
 
         # gradient accumulation check: if the iteration count divides evenly, run the optimizer step as well; otherwise skip it
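The deleted comments claimed that p_norm grows and g_norm shrinks as training progresses; the two values are still computed and logged. A sketch of what such norm helpers typically compute, offered as an assumption; the repo's own get_parameter_norm and get_grad_norm may differ in detail:

import torch

def parameter_l2_norm(parameters, norm_type=2):
    # Aggregate the L2 norm over all parameter tensors.
    total = 0.0
    for p in parameters:
        total += float(p.data.norm(norm_type) ** norm_type)
    return total ** (1.0 / norm_type)

def gradient_l2_norm(parameters, norm_type=2):
    # Same aggregation, but over gradients (skipping params without one).
    total = 0.0
    for p in parameters:
        if p.grad is not None:
            total += float(p.grad.data.norm(norm_type) ** norm_type)
    return total ** (1.0 / norm_type)

model = torch.nn.Linear(4, 4)
print(parameter_l2_norm(model.parameters()))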

src/11_seq2seq/train.py

Lines changed: 2 additions & 27 deletions
@@ -179,31 +179,6 @@ def get_optimizer(model, config):
     return optimizer
 
 
-def get_scheduler(optimizer, config):
-    '''
-    # learning_rate scheduler
-    Used to adjust the learning rate during training,
-    e.g. keep lr at 1 until epoch 9, then lower it to 0.5, 0.25, 0.125, and so on.
-    The code below decays the lr from epoch lr_decay_start onward, every config.lr_step epochs, by a factor of config.lr_gamma.
-    # However, this part is not used
-    '''
-    if config.lr_step > 0:
-        lr_scheduler = optim.lr_scheduler.MultiStepLR(
-            optimizer,
-            milestones=[i for i in range(
-                max(0, config.lr_decay_start - 1),
-                (config.init_epoch - 1) + config.n_epochs,
-                config.lr_step
-            )],
-            gamma=config.lr_gamma,
-            last_epoch=config.init_epoch - 1 if config.init_epoch > 1 else -1,
-        )
-    else:
-        lr_scheduler = None
-
-    return lr_scheduler
-
-
 def main(config, model_weight=None, opt_weight=None):
     def print_config(config):
         pp = pprint.PrettyPrinter(indent=4)
@@ -232,10 +207,10 @@ def print_config(config):
 
     optimizer = get_optimizer(model, config)
 
-    if opt_weight and (config.use_adam or config.use_radam):
+    if opt_weight:
         optimizer.load_state_dict(opt_weight)
 
-    lr_scheduler = get_scheduler(optimizer, config)
+    lr_scheduler = None
 
     if config.verbose >= 2:
         print(model)
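With get_scheduler() removed, lr_scheduler is now hard-coded to None, so any per-epoch hook that used to call the scheduler has to guard on it. A hedged sketch of that pattern with placeholder model and optimizer; this is not the repo's trainer code:

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = None                      # as in the updated main()

for epoch in range(3):
    optimizer.step()                     # stand-in for one epoch of training
    if lr_scheduler is not None:         # scheduling is simply skipped when disabled
        lr_scheduler.step()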

src/11_seq2seq/translate.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import torch
 from modules.data_loader import DataLoader
 import modules.data_loader as data_loader
-from modules.models.seq2seq import Seq2Seq
+from modules.seq2seq import Seq2Seq
 
 
 def define_argparser():
