import numpy as np
+
import torch
import torch.nn.utils as torch_utils
# from torch.cuda.amp import autocast
# from torch.cuda.amp import GradScaler
+
from ignite.engine import Engine
from ignite.engine import Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar
+
from modules.utils import get_grad_norm, get_parameter_norm

+
VERBOSE_SILENT = 0
VERBOSE_EPOCH_WISE = 1
VERBOSE_BATCH_WISE = 2
@@ -22,116 +26,66 @@ def __init__(self, func, model, crit, optimizer, lr_scheduler, config):
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.config = config
+
        super().__init__(func)

        self.best_loss = np.inf
        # self.scaler = GradScaler()

    @staticmethod
+    # @profile
    def train(engine, mini_batch):
+        # You have to reset the gradients of all model parameters
+        # before taking another step of gradient descent.
        engine.model.train()
-
-        '''
-        Gradient Accumulation
-        - For machine translation, a batch size of around 256 is appropriate;
-          the batch size itself affects performance,
-          but the GPU may not be able to handle it.
-
-        Speed aside, to preserve performance we deliberately skip the optimizer
-        step for N iterations at a time, so the desired performance is maintained.
-        (A standalone sketch of this pattern follows this hunk.)
-
-        1. engine.state.iteration % engine.config.iteration_per_update == 1
-           - call zero_grad whenever the current iteration modulo iteration_per_update is 1
-        2. engine.config.iteration_per_update == 1
-           - the usual case: call zero_grad on every iteration
-        '''
        if engine.state.iteration % engine.config.iteration_per_update == 1 or \
-            engine.config.iteration_per_update == 1:
-            if engine.state.iteration > 1:
+            engine.config.iteration_per_update == 1:
+            if engine.state.iteration > 1:
                engine.optimizer.zero_grad()

-        # The device is taken from the model's first parameter.
        device = next(engine.model.parameters()).device
-        '''
-        src and tgt are each a tuple of (sentence tensor, per-sentence lengths);
-        this is how torchtext provides them in the first place.
-        Only the sentence tensor is moved to GPU memory.
-        '''
        mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
-        mini_batch.tgt = (mini_batch.src[0].to(device), mini_batch.tgt[1])
-
-        '''
-        x goes in as the very first input;
-        y goes in so it can be checked against the final output.
-        x can be fed as-is (it does not matter that BOS/EOS are included).
-        For y,
-        - the per-sentence length information is dropped,
-        - and the leading BOS token is removed from each sentence
-          (because prediction starts from the word right after BOS).
-
-        x = (batch_size, length_n)
-        y = (batch_size, length_m)
-        '''
+        mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])
+
+        # The raw target variable has both BOS and EOS tokens.
+        # The output of the sequence-to-sequence model does not have a BOS token.
+        # Thus, remove the BOS token from the reference.
        x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
+        # |x| = (batch_size, length)
+        # |y| = (batch_size, length)

-        #-------------------------#
-        # Run training memory-efficiently with autocast.
-        # with autocast(not engine.config.off_autocast):
-        # y_hat = (batch_size, length_m, output_size)
-        # For the input tgt, remove the trailing EOS token.
+        # with autocast(not engine.config.off_autocast):
+        # Take the feed-forward pass.
+        # As before, the decoder input does not have an EOS token.
+        # Thus, remove the EOS token from the decoder input.
        y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
+        # |y_hat| = (batch_size, length, output_size)

-        '''
-        Reshape the tensors as follows to compute the loss;
-        think of it as laying every word of every sentence out in order.
-        Before (3D):
-            y_hat = (batch_size, length_m, output_size)
-            y = (batch_size, length_m)
-        After (2D):
-            y_hat = (batch_size * length_m, output_size)
-            y = (batch_size * length_m)
-        '''
        loss = engine.crit(
            y_hat.contiguous().view(-1, y_hat.size(-1)),
            y.contiguous().view(-1)
        )
-        '''
-        div(y.size(0)): after computing the loss, divide it by the batch size.
-        div(engine.config.iteration_per_update): divide in advance for
-            gradient accumulation.
-        In other words, backward_target is the loss value that actually gets applied.
-        '''
        backward_target = loss.div(y.size(0)).div(engine.config.iteration_per_update)
-        #-------------------------#

-        # If autocast is enabled, scale the loss before calling backward.
        # if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
        #     engine.scaler.scale(backward_target).backward()
        # else:
        backward_target.backward()

-        # Total number of tokens in the current batch.
        word_count = int(mini_batch.tgt[1].sum())
        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

-        # Gradient accumulation: if the iteration count divides evenly, take a step; otherwise skip it.
        if engine.state.iteration % engine.config.iteration_per_update == 0 and \
-            engine.state.iteration > 0:
-            '''
-            Gradient Clipping
-            The longer a sequence's time steps, the larger the gradients can become.
-            Used to keep an overly large g_norm from moving the parameters too far.
-            - Although this is said to be less necessary when using Adam.
-            '''
+            engine.state.iteration > 0:
+            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(
                engine.model.parameters(),
                engine.config.max_grad_norm,
            )
-
+            # Take a step of gradient descent.
            # if engine.config.gpu_id >= 0 and not engine.config.off_autocast:
-            #     # When using a GPU, step via the scaler instead of the usual optimizer.step().
+            #     # Use the scaler instead of engine.optimizer.step() when using a GPU.
            #     engine.scaler.step(engine.optimizer)
            #     engine.scaler.update()
            # else:
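# A minimal, self-contained sketch of the gradient-accumulation pattern used in
# train() above, combined with the autocast/GradScaler path that is commented out
# there, plus gradient clipping. It is an illustration only: names such as
# `train_step`, `iteration_per_update`, `max_grad_norm`, and `use_amp` are
# assumed stand-ins, not this project's API.
import torch
import torch.nn.utils as torch_utils
from torch.cuda.amp import GradScaler, autocast


def train_step(model, crit, optimizer, loader,
               iteration_per_update=32, max_grad_norm=5., use_amp=True):
    scaler = GradScaler(enabled=use_amp)
    model.train()

    for i, (x, y) in enumerate(loader, start=1):
        # Reset gradients only once every `iteration_per_update` iterations.
        if i % iteration_per_update == 1 or iteration_per_update == 1:
            optimizer.zero_grad()

        with autocast(enabled=use_amp):
            y_hat = model(x)
            # Divide by the number of accumulation steps so the summed gradient
            # matches what a single large batch would have produced.
            loss = crit(y_hat, y) / iteration_per_update

        # Gradients accumulate across iterations until the next zero_grad().
        scaler.scale(loss).backward()

        if i % iteration_per_update == 0:
            scaler.unscale_(optimizer)  # unscale before clipping gradient norms
            torch_utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)      # optimizer.step() only every N-th iteration
            scaler.update()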
@@ -156,17 +110,18 @@ def validate(engine, mini_batch):
            mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
            mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])

-            # x = (batch_size, length_n)
-            # y = (batch_size, length_m)
            x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
+            # |x| = (batch_size, length)
+            # |y| = (batch_size, length)

            # with autocast(not engine.config.off_autocast):
            y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
+            # |y_hat| = (batch_size, length, output_size)
            loss = engine.crit(
                y_hat.contiguous().view(-1, y_hat.size(-1)),
                y.contiguous().view(-1),
            )
-
+
        word_count = int(mini_batch.tgt[1].sum())
        loss = float(loss / word_count)
        ppl = np.exp(loss)
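# A quick numeric illustration of the two lines above: the summed cross-entropy
# is averaged per target token, and perplexity is its exponential. The numbers
# below are made up purely for illustration.
import numpy as np

total_nll = 249.3   # summed (not averaged) cross-entropy over the batch
word_count = 83     # number of target tokens in the batch (sum of tgt lengths)

loss_per_token = total_nll / word_count
ppl = np.exp(loss_per_token)  # perplexity = exp(average NLL per token)
print('loss=%.4f ppl=%.2f' % (loss_per_token, ppl))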
@@ -179,18 +134,18 @@ def validate(engine, mini_batch):
    @staticmethod
    def attach(
        train_engine, validation_engine,
-        training_metric_names=['loss', 'ppl', '|param|', '|g_param|'],
-        validation_metric_names=['loss', 'ppl'],
+        training_metric_names=['loss', 'ppl', '|param|', '|g_param|'],
+        validation_metric_names=['loss', 'ppl'],
        verbose=VERBOSE_BATCH_WISE,
    ):
-        # Function for reporting and printing the current status.
+        # Attaching would be repeated for several metrics.
+        # Thus, we can reduce repeated code by using this helper function.
        def attach_running_average(engine, metric_name):
            RunningAverage(output_transform=lambda x: x[metric_name]).attach(
                engine,
                metric_name,
            )

-        '''Train Attach Process'''
        for metric_name in training_metric_names:
            attach_running_average(train_engine, metric_name)

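# A standalone sketch of what attach_running_average() does: RunningAverage keeps
# an exponential moving average of one entry of the engine's output dict and
# exposes it under engine.state.metrics. The toy process function and data below
# are assumptions for illustration only.
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage


def step(engine, batch):
    # A process function may return a dict; output_transform picks one entry.
    return {'loss': float(batch)}


toy_engine = Engine(step)
RunningAverage(output_transform=lambda x: x['loss']).attach(toy_engine, 'loss')


@toy_engine.on(Events.ITERATION_COMPLETED)
def log_running_loss(engine):
    print(engine.state.metrics['loss'])


toy_engine.run([1.0, 2.0, 3.0], max_epochs=1)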
@@ -213,7 +168,6 @@ def print_train_logs(engine):
                    np.exp(avg_loss),
                ))

-        '''Validation Attach Process'''
        for metric_name in validation_metric_names:
            attach_running_average(validation_engine, metric_name)

@@ -249,7 +203,6 @@ def save_model(engine, train_engine, config, src_vocab, tgt_vocab):
        avg_train_loss = train_engine.state.metrics['loss']
        avg_valid_loss = engine.state.metrics['loss']

-        # Note: the model from every epoch is saved, not only the best model.
        # Set a filename for the model of the last epoch.
        # We need to put as much information as possible into the filename.
        model_fn = config.model_fn.split('.')
@@ -275,7 +228,7 @@ def save_model(engine, train_engine, config, src_vocab, tgt_vocab):
                'tgt_vocab': tgt_vocab,
            }, model_fn
        )
-
+


class Trainer():

@@ -291,6 +244,7 @@ def train(
        n_epochs,
        lr_scheduler=None
    ):
+        # Declare the train and validation engines with the necessary objects.
        train_engine = self.target_engine_class(
            self.target_engine_class.train,
            model,
@@ -308,31 +262,41 @@ def train(
            config=self.config
        )

+        # Run the necessary attach procedure on the train & validation engines.
+        # A progress bar and metrics are attached.
        self.target_engine_class.attach(
            train_engine,
            validation_engine,
            verbose=self.config.verbose
        )

+        # After every training epoch, run one validation epoch.
+        # Also, apply the LR scheduler if necessary.
        def run_validation(engine, validation_engine, valid_loader):
            validation_engine.run(valid_loader, max_epochs=1)
+
            if engine.lr_scheduler is not None:
                engine.lr_scheduler.step()

+        # Attach the callback function above.
        train_engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            run_validation,
            validation_engine,
            valid_loader
        )
+        # Attach another callback function for the start of training.
        train_engine.add_event_handler(
            Events.STARTED,
            self.target_engine_class.resume_training,
            self.config.init_epoch,
        )
+
+        # Attach a validation-loss check procedure at the end of every validation epoch.
        validation_engine.add_event_handler(
            Events.EPOCH_COMPLETED, self.target_engine_class.check_best
        )
+        # Attach a model-saving procedure at the end of every validation epoch.
        validation_engine.add_event_handler(
            Events.EPOCH_COMPLETED,
            self.target_engine_class.save_model,
@@ -342,7 +306,7 @@ def run_validation(engine, validation_engine, valid_loader):
            tgt_vocab,
        )

-        # Start training
+        # Start training.
        train_engine.run(train_loader, max_epochs=n_epochs)

-        return model
+        return model
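# A minimal, runnable sketch of the ignite callback wiring used in Trainer.train()
# above: a handler registered for EPOCH_COMPLETED receives the engine plus the
# extra positional arguments given to add_event_handler, which is exactly how
# run_validation gets its validation engine and loader. The toy process functions
# and data are assumptions for illustration only.
from ignite.engine import Engine, Events

toy_train_engine = Engine(lambda e, batch: float(batch))
toy_validation_engine = Engine(lambda e, batch: float(batch))


def run_toy_validation(engine, validation_engine, valid_data):
    # Run one full validation pass at the end of every training epoch.
    validation_engine.run(valid_data, max_epochs=1)


toy_train_engine.add_event_handler(
    Events.EPOCH_COMPLETED,
    run_toy_validation,      # called as handler(engine, *extra_args)
    toy_validation_engine,
    [10.0, 20.0],            # toy validation data
)

toy_train_engine.run([1.0, 2.0, 3.0], max_epochs=2)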