
Commit 9c6c9be

seq2 test
1 parent 509127b commit 9c6c9be

3 files changed: +396, -85 lines changed


src/11_seq2seq/modules/trainer.py

Lines changed: 48 additions & 84 deletions
@@ -1,14 +1,18 @@
import numpy as np
+
import torch
import torch.nn.utils as torch_utils
# from torch.cuda.amp import autocast
# from torch.cuda.amp import GradScaler
+
from ignite.engine import Engine
from ignite.engine import Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar
+
from modules.utils import get_grad_norm, get_parameter_norm

+
VERBOSE_SILENT = 0
VERBOSE_EPOCH_WISE = 1
VERBOSE_BATCH_WISE = 2
@@ -22,116 +26,66 @@ def __init__(self, func, model, crit, optimizer, lr_scheduler, config):
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.config = config
+
super().__init__(func)

self.best_loss = np.inf
#self.scaler = GradScaler()

@staticmethod
+ #@profile
def train(engine, mini_batch):
+ # You have to reset the gradients of all model parameters
+ # before taking another step of gradient descent.
engine.model.train()
-
- '''
- Gradient Accumulation
- - For machine translation, a batch_size of around 256 is appropriate;
- that is, the batch_size itself also affects performance,
- but it may not be feasible depending on the GPU.
-
- Speed aside, to preserve performance we deliberately
- skip about N iterations to keep the desired performance.
-
- 1. engine.state.iteration % engine.config.iteration_per_update == 1
- - perform zero_grad whenever the current iteration divided by per_update leaves a remainder of 1.
- 2. engine.config.iteration_per_update == 1
- - the ordinary case: just zero_grad every time.
-
- '''
if engine.state.iteration % engine.config.iteration_per_update == 1 or \
- engine.config.iteration_per_update == 1:
- if engine.state.iteration > 1:
+ engine.config.iteration_per_update == 1:
+ if engine.state.iteration > 1:
engine.optimizer.zero_grad()

- # The model's first parameter is the config.
device = next(engine.model.parameters()).device
- '''
- src and tgt are each tuples of (actual sentence data, length info of each sentence).
- - torchText provides them in that form to begin with.
- Of those, only the actual sentence data is sent to GPU memory.
- '''
mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
- mini_batch.tgt = (mini_batch.src[0].to(device), mini_batch.tgt[1])
-
- '''
- x goes in as the very first input.
- y goes in for checking against the final output.
- For x, just feed it as-is (it does not matter if BOS/EOS are included).
- For y,
- - the length info of each sentence is dropped,
- - and the leading BOS token is also removed from the actual sentence
- (because prediction is performed from the word right after BOS).
-
- x = (batch_size, length_n)
- y = (batch_size, length_m)
- '''
+ mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])
+
+ # The raw target variable has both BOS and EOS tokens.
+ # The output of sequence-to-sequence does not have a BOS token.
+ # Thus, remove the BOS token for the reference.
x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
+ # |x| = (batch_size, length)
+ # |y| = (batch_size, length)

- #-------------------------#
- # Run training memory-efficiently with autocast
- # with autocast(not engine.config.off_autocast):
- # y_hat = (batch_size, length_m, output_size)
- # For the input tgt, remove the EOS token at the end.
+ #with autocast(not engine.config.off_autocast):
+ # Take the feed-forward pass.
+ # Similar to before, the decoder input does not have an EOS token.
+ # Thus, remove the EOS token for the decoder input.
y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
+ # |y_hat| = (batch_size, length, output_size)

- '''
- Reshape the tensors as follows to compute the loss value.
- Think of it as laying out every word of every sentence in order.
- Before (3D):
- y_hat = (batch_size, length_m, output_size)
- y = (batch_size, length_m)
- After (2D):
- y_hat = (batch_size * length_m, output_size)
- y = (batch_size * length_m)
- '''
loss = engine.crit(
y_hat.contiguous().view(-1, y_hat.size(-1)),
y.contiguous().view(-1)
)
- '''
- div(y.size(0)): after computing the loss, divide it by batch_size,
- div(engine.config.iteration_per_update):
- divide in advance for gradient accumulation.
- That is, backward_target is the loss value that actually gets applied.
- '''
backward_target = loss.div(y.size(0)).div(engine.config.iteration_per_update)
- #-------------------------#

- # If autocast is on, backward after the scaling step.
# if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
# engine.scaler.scale(backward_target).backward()
# else:
backward_target.backward()

- # Number of all tokens in the current batch.
word_count = int(mini_batch.tgt[1].sum())
p_norm = float(get_parameter_norm(engine.model.parameters()))
g_norm = float(get_grad_norm(engine.model.parameters()))

- # Gradient accumulation check: if it divides evenly, also take a step; otherwise skip.
if engine.state.iteration % engine.config.iteration_per_update == 0 and \
- engine.state.iteration > 0:
- '''
- Gradient Clipping
- The longer a sequence's time steps, the larger the gradient can get.
- Used to keep a too-large g_norm from moving the parameters too much.
- - Though with Adam it is said to be not really necessary.
- '''
+ engine.state.iteration > 0:
+ # In order to avoid exploding gradients, we apply gradient clipping.
torch_utils.clip_grad_norm_(
engine.model.parameters(),
engine.config.max_grad_norm,
)
-
+ # Take a step of gradient descent.
# if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
- # # When using a GPU, step with the scaler instead of the usual optim.step().
+ # # Use scaler instead of engine.optimizer.step() if using GPU.
# engine.scaler.step(engine.optimizer)
# engine.scaler.update()
# else:
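The docstrings removed above describe the gradient accumulation scheme: zero_grad once per update window, divide the summed loss by the batch size and by iteration_per_update, then clip and step only when the window is complete. Below is a minimal sketch of the same idea in plain PyTorch, outside Ignite; model, crit, optimizer, loader, iteration_per_update and max_grad_norm are placeholder stand-ins for illustration, not objects from this repository.

import torch
import torch.nn as nn
import torch.nn.utils as torch_utils

# Placeholder objects standing in for the real trainer state.
model = nn.Linear(16, 8)
crit = nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters())
loader = [(torch.randn(4, 16), torch.randint(0, 8, (4,))) for _ in range(8)]

iteration_per_update = 4    # accumulate 4 mini-batches per optimizer step
max_grad_norm = 5.0

for i, (x, y) in enumerate(loader, start=1):
    # Reset gradients only at the start of each accumulation window,
    # mirroring "iteration % iteration_per_update == 1".
    if i % iteration_per_update == 1 or iteration_per_update == 1:
        optimizer.zero_grad()

    y_hat = model(x)
    # Divide the summed loss by batch size and by the accumulation factor
    # in advance, so the accumulated gradient matches one large batch.
    backward_target = crit(y_hat, y).div(y.size(0)).div(iteration_per_update)
    backward_target.backward()

    # Clip and step only when the accumulation window is complete.
    if i % iteration_per_update == 0:
        torch_utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

With iteration_per_update = 1 this collapses to the ordinary training step: zero_grad, backward, and step on every mini-batch.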
@@ -156,17 +110,18 @@ def validate(engine, mini_batch):
mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])

- # x = (batch_size, length_n)
- # y = (batch_size, length_m)
x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
+ # |x| = (batch_size, length)
+ # |y| = (batch_size, length)

#with autocast(not engine.config.off_autocast):
y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
+ # |y_hat| = (batch_size, length, output_size)
loss = engine.crit(
y_hat.contiguous().view(-1, y_hat.size(-1)),
y.contiguous().view(-1),
)
-
+
word_count = int(mini_batch.tgt[1].sum())
loss = float(loss / word_count)
ppl = np.exp(loss)
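Both train and validate rely on the same target slicing for teacher forcing (the decoder input keeps BOS and drops the last position, the reference drops BOS) and report a per-token loss that is exponentiated into perplexity. Here is a small self-contained sketch of just that slicing and the PPL computation; the token ids, the NLLLoss criterion and the fake logits are illustrative assumptions, not values from this codebase.

import numpy as np
import torch
import torch.nn as nn

PAD, BOS, EOS = 0, 1, 2                      # arbitrary special-token ids
batch_size, length, vocab_size = 2, 5, 10

# Fake target batch: every sentence starts with BOS and ends with EOS.
tgt = torch.tensor([[BOS, 5, 6, 7, EOS],
                    [BOS, 3, 4, EOS, PAD]])
tgt_lengths = torch.tensor([5, 4])

decoder_input = tgt[:, :-1]                  # keep BOS, drop the final position
reference = tgt[:, 1:]                       # drop BOS: what the model must predict

# Fake log-probabilities standing in for y_hat = model(x, decoder_input).
y_hat = torch.log_softmax(torch.randn(batch_size, length - 1, vocab_size), dim=-1)

crit = nn.NLLLoss(ignore_index=PAD, reduction='sum')
loss = crit(y_hat.contiguous().view(-1, vocab_size),
            reference.contiguous().view(-1))

# Per-token loss and perplexity, as in validate().
word_count = int(tgt_lengths.sum())
ppl = np.exp(float(loss) / word_count)
print('loss/word: %.4f  ppl: %.2f' % (float(loss) / word_count, ppl))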
@@ -179,18 +134,18 @@ def validate(engine, mini_batch):
@staticmethod
def attach(
train_engine, validation_engine,
- training_metric_names=['loss', 'ppl', '|param|', '|g_param|'],
- validation_metric_names=['loss', 'ppl'],
+ training_metric_names = ['loss', 'ppl', '|param|', '|g_param|'],
+ validation_metric_names = ['loss', 'ppl'],
verbose=VERBOSE_BATCH_WISE,
):
- # Function for reporting and printing the current status.
+ # Attaching would be repeated for several metrics.
+ # Thus, we can reduce the repeated code by using this function.
def attach_running_average(engine, metric_name):
RunningAverage(output_transform=lambda x: x[metric_name]).attach(
engine,
metric_name,
)

- '''Train Attach Process'''
for metric_name in training_metric_names:
attach_running_average(train_engine, metric_name)

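attach_running_average above is a thin wrapper around ignite's RunningAverage so the same wiring is not repeated for every metric name. A standalone sketch of that pattern on a dummy engine follows; the step function and its fabricated 'loss' values are illustrative only.

from ignite.engine import Engine
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar

# Dummy process function: pretend every iteration returns a metric dict.
def step(engine, batch):
    return {'loss': float(batch)}

engine = Engine(step)

def attach_running_average(engine, metric_name):
    # Running average over one key of the dict returned by the process function.
    RunningAverage(output_transform=lambda x: x[metric_name]).attach(
        engine,
        metric_name,
    )

for metric_name in ['loss']:
    attach_running_average(engine, metric_name)

# Batch-wise progress bar showing the attached metric (cf. VERBOSE_BATCH_WISE).
ProgressBar().attach(engine, ['loss'])

engine.run([3.0, 2.0, 1.0], max_epochs=1)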
@@ -213,7 +168,6 @@ def print_train_logs(engine):
np.exp(avg_loss),
))

- '''Validation Attach Process'''
for metric_name in validation_metric_names:
attach_running_average(validation_engine, metric_name)

@@ -249,7 +203,6 @@ def save_model(engine, train_engine, config, src_vocab, tgt_vocab):
avg_train_loss = train_engine.state.metrics['loss']
avg_valid_loss = engine.state.metrics['loss']

- # Note: this saves the model for every epoch, not only the best model.
# Set a filename for the model of the last epoch.
# We need to put as much information as possible into the filename.
model_fn = config.model_fn.split('.')
@@ -275,7 +228,7 @@ def save_model(engine, train_engine, config, src_vocab, tgt_vocab):
'tgt_vocab': tgt_vocab,
}, model_fn
)
-
+

class Trainer():

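The diff only shows the tail of the checkpoint dict ('tgt_vocab') and the comments about encoding information into the filename. Below is a hedged sketch of that checkpointing pattern; the filename scheme, the other dict keys and all concrete values are assumptions for illustration, not taken from this commit.

import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-ins; the real values come from the engines and config.
base_model_fn = 'seq2seq.pth'
epoch, avg_train_loss, avg_valid_loss = 3, 2.41, 2.87
model = nn.Linear(4, 4)
src_vocab, tgt_vocab = {'<bos>': 1}, {'<bos>': 1}

# Encode epoch, losses and their perplexities into the filename,
# so every saved epoch is distinguishable at a glance.
parts = base_model_fn.split('.')
model_fn = '.'.join(
    parts[:-1]
    + ['%02d' % epoch,
       '%.2f-%.2f' % (avg_train_loss, np.exp(avg_train_loss)),
       '%.2f-%.2f' % (avg_valid_loss, np.exp(avg_valid_loss))]
    + [parts[-1]]
)

# Save everything needed to resume training or to translate later.
torch.save(
    {
        'model': model.state_dict(),
        'src_vocab': src_vocab,
        'tgt_vocab': tgt_vocab,
    }, model_fn
)
print('saved to', model_fn)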
@@ -291,6 +244,7 @@ def train(
n_epochs,
lr_scheduler=None
):
+ # Declare train and validation engines with the necessary objects.
train_engine = self.target_engine_class(
self.target_engine_class.train,
model,
@@ -308,31 +262,41 @@ def train(
config=self.config
)

+ # Do the necessary attach procedure for the train & validation engines.
+ # Progress bar and metrics would be attached.
self.target_engine_class.attach(
train_engine,
validation_engine,
verbose=self.config.verbose
)

+ # After every train epoch, run 1 validation epoch.
+ # Also, apply the LR scheduler if necessary.
def run_validation(engine, validation_engine, valid_loader):
validation_engine.run(valid_loader, max_epochs=1)
+
if engine.lr_scheduler is not None:
engine.lr_scheduler.step()

+ # Attach the call-back function above.
train_engine.add_event_handler(
Events.EPOCH_COMPLETED,
run_validation,
validation_engine,
valid_loader
)
+ # Attach another call-back function for the initiation of training.
train_engine.add_event_handler(
Events.STARTED,
self.target_engine_class.resume_training,
self.config.init_epoch,
)
+
+ # Attach a validation loss check procedure at the end of every validation epoch.
validation_engine.add_event_handler(
Events.EPOCH_COMPLETED, self.target_engine_class.check_best
)
+ # Attach a model save procedure at the end of every validation epoch.
validation_engine.add_event_handler(
Events.EPOCH_COMPLETED,
self.target_engine_class.save_model,
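The handler registrations above use ignite's add_event_handler: the engine that fires the event is passed to the handler first, followed by the extra arguments given at registration time. Here is a self-contained sketch of that mechanism with dummy engines and handlers; resume_training below is a simplified illustration, not the repository's implementation.

from ignite.engine import Engine, Events

train_engine = Engine(lambda engine, batch: None)
validation_engine = Engine(lambda engine, batch: None)
valid_loader = [0, 1, 2]

# Handler signature: the triggering engine first, then the extra
# arguments that were passed to add_event_handler.
def run_validation(engine, validation_engine, valid_loader):
    validation_engine.run(valid_loader, max_epochs=1)

# After every training epoch, run one validation epoch.
train_engine.add_event_handler(
    Events.EPOCH_COMPLETED,
    run_validation,
    validation_engine,
    valid_loader,
)

# At the start of training, e.g. restore the epoch counter.
def resume_training(engine, init_epoch):
    engine.state.epoch = max(0, init_epoch - 1)

train_engine.add_event_handler(Events.STARTED, resume_training, 1)

train_engine.run([0, 1], max_epochs=2)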
@@ -342,7 +306,7 @@ def run_validation(engine, validation_engine, valid_loader):
tgt_vocab,
)

- # Start training
+ # Start training.
train_engine.run(train_loader, max_epochs=n_epochs)

- return model
+ return model
