
Commit 186f48d

Create trainer_amp.py
committed
1 parent f6531c1 commit 186f48d

1 file changed: 252 additions & 0 deletions

@@ -0,0 +1,252 @@
import numpy as np

import torch
import torch.nn.utils as torch_utils
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler

from ignite.engine import Engine
from ignite.engine import Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar

from simple_nmt.utils import get_grad_norm, get_parameter_norm


VERBOSE_SILENT = 0
VERBOSE_EPOCH_WISE = 1
VERBOSE_BATCH_WISE = 2


class AmpEngine(Engine):

    def __init__(self, func, model, crit, optimizer, lr_scheduler, config):
        self.model = model
        self.crit = crit
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.config = config

        super().__init__(func)

        self.best_loss = np.inf
        self.scaler = GradScaler()
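    # Usage note (a minimal sketch, not part of the original file): the engine
    # wraps an ignite process function, so a hypothetical wiring could look like
    #
    #     train_engine = AmpEngine(AmpEngine.train, model, crit, optimizer, lr_scheduler, config)
    #     train_engine.run(train_loader, max_epochs=config.n_epochs)
    #
    # where `train_loader` and `config.n_epochs` are assumed names from the caller.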
    @staticmethod
    #@profile
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()
        if engine.state.iteration % engine.config.iteration_per_update == 1 or \
            engine.config.iteration_per_update == 1:
            if engine.state.iteration > 1:
                engine.optimizer.zero_grad()

        device = next(engine.model.parameters()).device
        mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
        mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])

        # Raw target variable has both BOS and EOS tokens.
        # The output of the sequence-to-sequence model does not have a BOS token.
        # Thus, remove the BOS token for the reference.
        x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # Run training memory-efficiently with autocast.
        with autocast(not engine.config.off_autocast):
            # y_hat = (batch_size, length_m, output_size)
            # For the decoder input tgt, drop the trailing EOS token.
            y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
            # |y_hat| = (batch_size, length, output_size)

            '''
            Reshape the tensors as follows to compute the loss,
            i.e. treat every word of every sentence as one row.
            Before (3D):
                y_hat = (batch_size, length_m, output_size)
                y     = (batch_size, length_m)
            After (2D):
                y_hat = (batch_size * length_m, output_size)
                y     = (batch_size * length_m)
            '''
            loss = engine.crit(
                y_hat.contiguous().view(-1, y_hat.size(-1)),
                y.contiguous().view(-1)
            )
            '''
            div(y.size(0)): after computing the loss, divide it by batch_size.
            div(engine.config.iteration_per_update):
                divide again in advance for gradient accumulation.
            In other words, backward_target is the loss value that actually gets backpropagated.
            '''
            backward_target = loss.div(y.size(0)).div(engine.config.iteration_per_update)

        if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
            engine.scaler.scale(backward_target).backward()
        else:
            backward_target.backward()
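        # Note (illustrative, not part of the original file): the scale-then-backward
        # call above follows the standard torch.cuda.amp pattern. A minimal standalone
        # sketch, assuming placeholder names `model`, `opt` and `loss_fn`:
        #
        #     scaler = GradScaler()
        #     with autocast():
        #         loss = loss_fn(model(x), y)
        #     scaler.scale(loss).backward()
        #     scaler.step(opt)
        #     scaler.update()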
        word_count = int(mini_batch.tgt[1].sum())
        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        if engine.state.iteration % engine.config.iteration_per_update == 0 and \
            engine.state.iteration > 0:
            '''
            Gradient clipping:
            the more time steps a sequence has, the larger the gradient can become.
            Clipping prevents a too-large g_norm from moving the parameters too far.
            (With Adam, though, it is said to be not really necessary.)
            '''
            torch_utils.clip_grad_norm_(
                engine.model.parameters(),
                engine.config.max_grad_norm,
            )
            # Take a step of gradient descent.
            if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
                # When using the GPU, step through the scaler instead of the usual optimizer.step().
                engine.scaler.step(engine.optimizer)
                engine.scaler.update()
            else:
                engine.optimizer.step()
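        # Worked example (assumed numbers, not from the original file): with
        # batch_size=32 and iteration_per_update=8, gradients are accumulated over
        # 8 mini-batches and the optimizer steps once every 8 iterations, for an
        # effective batch size of 32 * 8 = 256; dividing the loss by
        # iteration_per_update above keeps the update scale comparable to a single
        # batch of that size.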
        loss = float(loss / word_count)
        ppl = np.exp(loss)
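        # For intuition (illustrative numbers only): a per-word loss of 3.0
        # corresponds to a perplexity of exp(3.0) ≈ 20.1.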
        return {
            'loss': loss,
            'ppl': ppl,
            '|param|': p_norm if not np.isnan(p_norm) and not np.isinf(p_norm) else 0.,
            '|g_param|': g_norm if not np.isnan(g_norm) and not np.isinf(g_norm) else 0.,
        }

    @staticmethod
    def validate(engine, mini_batch):
        engine.model.eval()

        with torch.no_grad():
            device = next(engine.model.parameters()).device
            mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
            mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])

            x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            with autocast(not engine.config.off_autocast):
                y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
                # |y_hat| = (batch_size, n_classes)
                loss = engine.crit(
                    y_hat.contiguous().view(-1, y_hat.size(-1)),
                    y.contiguous().view(-1),
                )

        word_count = int(mini_batch.tgt[1].sum())
        loss = float(loss / word_count)
        ppl = np.exp(loss)

        return {
            'loss': loss,
            'ppl': ppl,
        }

    @staticmethod
    def attach(
        train_engine, validation_engine,
        training_metric_names = ['loss', 'ppl', '|param|', '|g_param|'],
        validation_metric_names = ['loss', 'ppl'],
        verbose=VERBOSE_BATCH_WISE,
    ):
        # Attaching would be repeated for several metrics.
        # Thus, we can reduce the repeated code by using this function.
        def attach_running_average(engine, metric_name):
            RunningAverage(output_transform=lambda x: x[metric_name]).attach(
                engine,
                metric_name,
            )

        for metric_name in training_metric_names:
            attach_running_average(train_engine, metric_name)

        if verbose >= VERBOSE_BATCH_WISE:
            pbar = ProgressBar(bar_format=None, ncols=120)
            pbar.attach(train_engine, training_metric_names)

        if verbose >= VERBOSE_EPOCH_WISE:
            @train_engine.on(Events.EPOCH_COMPLETED)
            def print_train_logs(engine):
                avg_p_norm = engine.state.metrics['|param|']
                avg_g_norm = engine.state.metrics['|g_param|']
                avg_loss = engine.state.metrics['loss']

                print('Epoch {} - |param|={:.2e} |g_param|={:.2e} loss={:.4e} ppl={:.2f}'.format(
                    engine.state.epoch,
                    avg_p_norm,
                    avg_g_norm,
                    avg_loss,
                    np.exp(avg_loss),
                ))

        for metric_name in validation_metric_names:
            attach_running_average(validation_engine, metric_name)

        if verbose >= VERBOSE_BATCH_WISE:
            pbar = ProgressBar(bar_format=None, ncols=120)
            pbar.attach(validation_engine, validation_metric_names)

        if verbose >= VERBOSE_EPOCH_WISE:
            @validation_engine.on(Events.EPOCH_COMPLETED)
            def print_valid_logs(engine):
                avg_loss = engine.state.metrics['loss']

                print('Validation - loss={:.4e} ppl={:.2f} best_loss={:.4e} best_ppl={:.2f}'.format(
                    avg_loss,
                    np.exp(avg_loss),
                    engine.best_loss,
                    np.exp(engine.best_loss),
                ))
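    # Wiring sketch (an assumption for illustration; the original file does not show
    # the caller). A training script could connect the static methods roughly as
    #
    #     AmpEngine.attach(train_engine, validation_engine, verbose=VERBOSE_BATCH_WISE)
    #     validation_engine.add_event_handler(Events.EPOCH_COMPLETED, AmpEngine.check_best)
    #     validation_engine.add_event_handler(
    #         Events.EPOCH_COMPLETED,
    #         AmpEngine.save_model, train_engine, config, src_vocab, tgt_vocab,
    #     )
    #
    # where `config`, `src_vocab`, and `tgt_vocab` are assumed to come from the caller.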
    @staticmethod
    def resume_training(engine, resume_epoch):
        engine.state.iteration = (resume_epoch - 1) * len(engine.state.dataloader)
        engine.state.epoch = (resume_epoch - 1)

    @staticmethod
    def check_best(engine):
        loss = float(engine.state.metrics['loss'])
        if loss <= engine.best_loss:
            engine.best_loss = loss

    @staticmethod
    def save_model(engine, train_engine, config, src_vocab, tgt_vocab):
        avg_train_loss = train_engine.state.metrics['loss']
        avg_valid_loss = engine.state.metrics['loss']

        # Set a filename for the model of the last epoch.
        # We want to put as much information into the filename as possible.
        model_fn = config.model_fn.split('.')

        model_fn = model_fn[:-1] + ['%02d' % train_engine.state.epoch,
                                    '%.2f-%.2f' % (avg_train_loss,
                                                   np.exp(avg_train_loss)
                                                   ),
                                    '%.2f-%.2f' % (avg_valid_loss,
                                                   np.exp(avg_valid_loss)
                                                   )
                                    ] + [model_fn[-1]]

        model_fn = '.'.join(model_fn)
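        # Example of the resulting name (hypothetical numbers): with
        # config.model_fn == 'model.pth', epoch 2, train loss 3.21 and valid loss 3.55,
        # this yields 'model.02.3.21-24.78.3.55-34.81.pth'
        # (loss-ppl pairs for train and validation, respectively).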
        # Unlike other tasks, we need to save the current model, not the best model.
        torch.save(
            {
                'model': engine.model.state_dict(),
                'opt': train_engine.optimizer.state_dict(),
                'config': config,
                'src_vocab': src_vocab,
                'tgt_vocab': tgt_vocab,
            }, model_fn
        )
