 
 """Pretrain utilities."""
 
-import deepspeed
-from megatron.utils import report_memory, flops_calculator
-from megatron.schedules import forward_backward_pipelining_with_interleaving
-from megatron.schedules import forward_backward_pipelining_without_interleaving
-from megatron.schedules import forward_backward_no_pipelining
-from megatron.utils import calc_params_l2_norm
-from megatron.data.data_samplers import build_pretraining_data_loader
-from megatron.utils import unwrap_model
-from megatron.utils import check_adlr_autoresume_termination
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.learning_rates import AnnealingLR
-from megatron.initialize import write_args_to_tensorboard
-from megatron.initialize import initialize_megatron
-from megatron.optimizer import get_megatron_optimizer
-from megatron.model import Float16Module
-from megatron.checkpointing import save_checkpoint
-from megatron.checkpointing import load_checkpoint
-from megatron import print_rank_last
-from megatron import print_rank_0
-from megatron import mpu
-from megatron import update_num_microbatches
-from megatron import is_last_rank
-from megatron import get_num_microbatches
-from megatron import get_current_global_batch_size
-from megatron import get_tensorboard_writer
-from megatron import get_timers
-from megatron import get_args
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-import torch
 from datetime import datetime
 import math
 import sys
 import time
 # The earliest we can measure the start time.
 _TRAIN_START_TIME = time.time()
 
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+
+from megatron import get_args
+from megatron import get_timers
+from megatron import get_tensorboard_writer
+from megatron import get_current_global_batch_size
+from megatron import get_num_microbatches
+from megatron import is_last_rank
+from megatron import update_num_microbatches
+from megatron import mpu
+from megatron import print_rank_0
+from megatron import print_rank_last
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
+from megatron.model import Float16Module
+from megatron.optimizer import get_megatron_optimizer
+from megatron.initialize import initialize_megatron
+from megatron.initialize import write_args_to_tensorboard
+from megatron.learning_rates import AnnealingLR
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
+from megatron.data.data_samplers import build_pretraining_data_loader
+from megatron.utils import calc_params_l2_norm
+from megatron.schedules import forward_backward_no_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
+from megatron.schedules import forward_backward_pipelining_with_interleaving
+from megatron.utils import report_memory, flops_calculator
+
+import deepspeed
+
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -159,7 +162,6 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
-
 def update_train_iters(args):
 
     # For iteration-based training, we don't need to do anything
@@ -184,7 +186,7 @@ def update_train_iters(args):
         # Constant phase
         # Note that we throw away any partial last batch.
         iterations += (args.train_samples - consumed_samples) // \
-                      args.global_batch_size
+                      args.global_batch_size
         args.train_iters = iterations
 
     print_rank_0('setting training iterations to {}'.format(args.train_iters))
@@ -216,6 +218,7 @@ def get_model(model_provider_func):
             post_process=post_process
         )
 
+
     if not isinstance(model, list):
         model = [model]
 
@@ -231,10 +234,10 @@ def get_model(model_provider_func):
     if mpu.get_data_parallel_rank() == 0:
         print(' > number of parameters on (tensor, pipeline) '
               'model parallel rank ({}, {}): {}'.format(
-            mpu.get_tensor_model_parallel_rank(),
-            mpu.get_pipeline_model_parallel_rank(),
-            sum([sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() for p in model_module.parameters()])
-                 for model_module in model])), flush=True)
+            mpu.get_tensor_model_parallel_rank(),
+            mpu.get_pipeline_model_parallel_rank(),
+            sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
+                 for model_module in model])), flush=True)
 
     if args.deepspeed:
         return model
@@ -358,7 +361,7 @@ def setup_model_and_optimizer(model_provider_func):
 
     # get model without FP16 and/or TorchDDP wrappers
     if args.iteration == 0 and len(unwrapped_model) == 1 \
-       and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
+       and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
         print_rank_0("Initializing ICT from pretrained BERT model")
         unwrapped_model[0].init_state_dict_from_bert()
         if args.fp16:
@@ -379,7 +382,7 @@ def train_step(forward_step_func, data_iterator,
         skipped_iter = 0
         grad_norm = 0.
         num_zeros_in_grad = 0
-        return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
+        return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
 
     # Set grad to zero.
     if not args.deepspeed:
@@ -439,8 +442,8 @@ def train_step(forward_step_func, data_iterator,
     timers('optimizer').start()
     if args.deepspeed:
         increment = get_num_microbatches() * \
-                    args.micro_batch_size * \
-                    args.data_parallel_size
+                    args.micro_batch_size * \
+                    args.data_parallel_size
         model[0].step(lr_kwargs={'increment': increment})
         update_successful = model[0].was_step_applied()
     else:
@@ -455,8 +458,8 @@ def train_step(forward_step_func, data_iterator,
     else:
         if update_successful:
             increment = get_num_microbatches() * \
-                        args.micro_batch_size * \
-                        args.data_parallel_size
+                        args.micro_batch_size * \
+                        args.data_parallel_size
             lr_scheduler.step(increment=increment)
             skipped_iter = 0
         else:
@@ -504,8 +507,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         else:
             value = loss_dict[key].float().sum().item()
             is_nan = value == float('inf') or \
-                     value == -float('inf') or \
-                     value != value
+                     value == -float('inf') or \
+                     value != value
             got_nan = got_nan or is_nan
     total_loss_dict[nan_iters_key] = total_loss_dict.get(
         nan_iters_key, 0) + int(got_nan)
@@ -539,10 +542,10 @@ def add_to_logging(name):
         get_num_microbatches()
 
     total_iterations = total_loss_dict[advanced_iters_key] + \
-                       total_loss_dict[skipped_iters_key]
+                       total_loss_dict[skipped_iters_key]
 
     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0) and \
+    if writer and (iteration % args.tensorboard_log_interval == 0) and \
        is_last_rank():
         if args.log_learning_rate_to_tensorboard:
             writer.add_scalar('learning-rate', learning_rate, iteration)
@@ -553,7 +556,7 @@ def add_to_logging(name):
             writer.add_scalar('batch-size vs samples', batch_size,
                               args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key, loss_dict[key], iteration)
             writer.add_scalar(key + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
         if args.log_loss_scale_to_tensorboard:
@@ -595,7 +598,7 @@ def add_to_logging(name):
             if key not in [advanced_iters_key, skipped_iters_key,
                            nan_iters_key]:
                 avg = total_loss_dict[key].item() / \
-                      float(max(1, total_loss_dict[advanced_iters_key]))
+                      float(max(1, total_loss_dict[advanced_iters_key]))
                 if avg > 0.0:
                     log_string += ' {}: {:.6E} |'.format(key, avg)
                 total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
@@ -663,10 +666,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.deepspeed:
             # inform deepspeed of any batch size changes
             global_batch_size = mpu.get_data_parallel_world_size() * \
-                                args.micro_batch_size * \
-                                get_num_microbatches()
+                                args.micro_batch_size * \
+                                get_num_microbatches()
             model[0].set_train_batch_size(global_batch_size)
 
+
         loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func,
                        train_data_iterator,
@@ -675,8 +679,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
-                                       args.micro_batch_size * \
-                                       get_num_microbatches()
+                                       args.micro_batch_size * \
+                                       get_num_microbatches()
 
         # Logging.
         if args.deepspeed:
@@ -739,6 +743,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
 
+
     return iteration
 
 
@@ -767,17 +772,17 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                     forward_backward_func = forward_backward_pipelining_without_interleaving
             else:
                 forward_backward_func = forward_backward_no_pipelining
-
+
             if args.deepspeed:
                 # DeepSpeed uses eval_batch() and already aggregates losses.
                 assert isinstance(model, list) and len(model) == 1
                 loss = model[0].eval_batch(data_iterator)
-                loss_dicts = [{'lm loss' : loss}] * get_num_microbatches()
+                loss_dicts = [{'lm loss' : loss}] * get_num_microbatches()
             else:
                 loss_dicts = forward_backward_func(
                     forward_step_func, data_iterator, model, optimizer=None,
                     timers=None, forward_only=True)
-
+
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 # Reduce across processes.
                 for loss_dict in loss_dicts:
@@ -786,8 +791,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                             key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
 
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
-                                           * args.micro_batch_size \
-                                           * get_num_microbatches()
+                                           * args.micro_batch_size \
+                                           * get_num_microbatches()
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()
@@ -797,7 +802,6 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
 
     return total_loss_dict
 
-
 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
                                iteration, verbose=False):
@@ -835,7 +839,6 @@ def cyclic_iter(iter):
         for x in iter:
             yield x
 
-
 def build_train_valid_test_data_iterators(
         build_train_valid_test_datasets_provider):
     """XXX"""
@@ -865,7 +868,7 @@ def build_train_valid_test_data_iterators(
     else:
         train_samples = args.train_iters * args.global_batch_size
     eval_iters = (args.train_iters // args.eval_interval + 1) * \
-                 args.eval_iters
+                 args.eval_iters
     test_iters = args.eval_iters
     train_val_test_num_samples = [train_samples,
                                   eval_iters * args.global_batch_size,
@@ -904,25 +907,26 @@ def build_train_valid_test_data_iterators(
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()
 
+
     # Build iterators.
     dl_type = args.dataloader_type
     assert dl_type in ['single', 'cyclic']
 
     if train_dataloader is not None:
         train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
-                              else iter(cyclic_iter(train_dataloader))
+                              else iter(cyclic_iter(train_dataloader))
     else:
         train_data_iterator = None
 
     if valid_dataloader is not None:
         valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
-                              else iter(cyclic_iter(valid_dataloader))
+                              else iter(cyclic_iter(valid_dataloader))
     else:
         valid_data_iterator = None
 
     if test_dataloader is not None:
         test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
-                             else iter(cyclic_iter(test_dataloader))
+                             else iter(cyclic_iter(test_dataloader))
     else:
         test_data_iterator = None
 
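Note: the batch-size bookkeeping touched in several hunks above (train_step, train, evaluate) relies on one identity: a single iteration consumes data_parallel_size * micro_batch_size * num_microbatches samples, and that count is the increment handed to the learning-rate scheduler and added to args.consumed_train_samples. A minimal sketch of that arithmetic with illustrative values (in the real code these come from get_args(), mpu.get_data_parallel_world_size(), and get_num_microbatches()):

# Illustrative values only -- in megatron/training.py these come from
# get_args(), mpu.get_data_parallel_world_size(), and get_num_microbatches().
data_parallel_size = 8    # data-parallel replicas
micro_batch_size = 4      # samples per replica per microbatch
num_microbatches = 16     # gradient-accumulation steps per iteration

# Samples consumed by one optimizer step; the same `increment` passed to
# lr_scheduler.step() / model[0].step() in train_step().
global_batch_size = data_parallel_size * micro_batch_size * num_microbatches
assert global_batch_size == 512

# train() advances args.consumed_train_samples by this amount every iteration;
# evaluate() does the same for args.consumed_valid_samples.
consumed_train_samples = 0
for _ in range(3):
    consumed_train_samples += global_batch_size
print(consumed_train_samples)  # 1536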