
Commit 3644a9d

committed
Manual revert training.py to main version
1 parent 7222a97 commit 3644a9d

File tree: megatron/training.py

1 file changed: +66 -62 lines


megatron/training.py

Lines changed: 66 additions & 62 deletions
@@ -15,42 +15,45 @@
 
 """Pretrain utilities."""
 
-import deepspeed
-from megatron.utils import report_memory, flops_calculator
-from megatron.schedules import forward_backward_pipelining_with_interleaving
-from megatron.schedules import forward_backward_pipelining_without_interleaving
-from megatron.schedules import forward_backward_no_pipelining
-from megatron.utils import calc_params_l2_norm
-from megatron.data.data_samplers import build_pretraining_data_loader
-from megatron.utils import unwrap_model
-from megatron.utils import check_adlr_autoresume_termination
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.learning_rates import AnnealingLR
-from megatron.initialize import write_args_to_tensorboard
-from megatron.initialize import initialize_megatron
-from megatron.optimizer import get_megatron_optimizer
-from megatron.model import Float16Module
-from megatron.checkpointing import save_checkpoint
-from megatron.checkpointing import load_checkpoint
-from megatron import print_rank_last
-from megatron import print_rank_0
-from megatron import mpu
-from megatron import update_num_microbatches
-from megatron import is_last_rank
-from megatron import get_num_microbatches
-from megatron import get_current_global_batch_size
-from megatron import get_tensorboard_writer
-from megatron import get_timers
-from megatron import get_args
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-import torch
 from datetime import datetime
 import math
 import sys
 import time
 # The earliest we can measure the start time.
 _TRAIN_START_TIME = time.time()
 
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+
+from megatron import get_args
+from megatron import get_timers
+from megatron import get_tensorboard_writer
+from megatron import get_current_global_batch_size
+from megatron import get_num_microbatches
+from megatron import is_last_rank
+from megatron import update_num_microbatches
+from megatron import mpu
+from megatron import print_rank_0
+from megatron import print_rank_last
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
+from megatron.model import Float16Module
+from megatron.optimizer import get_megatron_optimizer
+from megatron.initialize import initialize_megatron
+from megatron.initialize import write_args_to_tensorboard
+from megatron.learning_rates import AnnealingLR
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
+from megatron.data.data_samplers import build_pretraining_data_loader
+from megatron.utils import calc_params_l2_norm
+from megatron.schedules import forward_backward_no_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
+from megatron.schedules import forward_backward_pipelining_with_interleaving
+from megatron.utils import report_memory, flops_calculator
+
+import deepspeed
+
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -159,7 +162,6 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
-
 def update_train_iters(args):
 
     # For iteration-based training, we don't need to do anything
@@ -184,7 +186,7 @@ def update_train_iters(args):
         # Constant phase
         # Note that we throw away any partial last batch.
         iterations += (args.train_samples - consumed_samples) // \
-            args.global_batch_size
+                      args.global_batch_size
         args.train_iters = iterations
 
     print_rank_0('setting training iterations to {}'.format(args.train_iters))
@@ -216,6 +218,7 @@ def get_model(model_provider_func):
             post_process=post_process
         )
 
+
     if not isinstance(model, list):
         model = [model]
 
@@ -231,10 +234,10 @@
     if mpu.get_data_parallel_rank() == 0:
         print(' > number of parameters on (tensor, pipeline) '
               'model parallel rank ({}, {}): {}'.format(
-                  mpu.get_tensor_model_parallel_rank(),
-                  mpu.get_pipeline_model_parallel_rank(),
-                  sum([sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() for p in model_module.parameters()])
-                       for model_module in model])), flush=True)
+            mpu.get_tensor_model_parallel_rank(),
+            mpu.get_pipeline_model_parallel_rank(),
+            sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
+                for model_module in model])), flush=True)
 
     if args.deepspeed:
         return model
@@ -358,7 +361,7 @@ def setup_model_and_optimizer(model_provider_func):
 
     # get model without FP16 and/or TorchDDP wrappers
     if args.iteration == 0 and len(unwrapped_model) == 1 \
-            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
+        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
         print_rank_0("Initializing ICT from pretrained BERT model")
         unwrapped_model[0].init_state_dict_from_bert()
         if args.fp16:
@@ -379,7 +382,7 @@ def train_step(forward_step_func, data_iterator,
         skipped_iter = 0
         grad_norm = 0.
         num_zeros_in_grad = 0
-        return {'lm loss': loss}, skipped_iter, grad_norm, num_zeros_in_grad
+        return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
 
     # Set grad to zero.
     if not args.deepspeed:
@@ -439,8 +442,8 @@ def train_step(forward_step_func, data_iterator,
     timers('optimizer').start()
     if args.deepspeed:
         increment = get_num_microbatches() * \
-            args.micro_batch_size * \
-            args.data_parallel_size
+                    args.micro_batch_size * \
+                    args.data_parallel_size
         model[0].step(lr_kwargs={'increment': increment})
         update_successful = model[0].was_step_applied()
     else:
@@ -455,8 +458,8 @@ def train_step(forward_step_func, data_iterator,
     else:
         if update_successful:
             increment = get_num_microbatches() * \
-                args.micro_batch_size * \
-                args.data_parallel_size
+                        args.micro_batch_size * \
+                        args.data_parallel_size
             lr_scheduler.step(increment=increment)
             skipped_iter = 0
         else:
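
Note: by their text content, the two train_step() hunks above change only how the wrapped increment computation is indented; the value itself is unchanged. For reference, a minimal sketch of that arithmetic in plain Python, with illustrative numbers that are not part of this commit (the real values come from get_args(), get_num_microbatches() and mpu.get_data_parallel_world_size()):

# Sketch only: mirrors the batch-size arithmetic in the train_step() hunks.
micro_batch_size = 4       # samples per micro-batch on each data-parallel rank
num_microbatches = 8       # micro-batches accumulated per optimizer step
data_parallel_size = 16    # number of data-parallel replicas

# Samples consumed by one optimizer step (the `increment` in the diff), which
# is passed to model[0].step(lr_kwargs=...) or lr_scheduler.step(increment=...).
increment = num_microbatches * micro_batch_size * data_parallel_size
assert increment == 512    # also the effective global batch size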
@@ -504,8 +507,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         else:
             value = loss_dict[key].float().sum().item()
             is_nan = value == float('inf') or \
-                value == -float('inf') or \
-                value != value
+                     value == -float('inf') or \
+                     value != value
             got_nan = got_nan or is_nan
     total_loss_dict[nan_iters_key] = total_loss_dict.get(
         nan_iters_key, 0) + int(got_nan)
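
Note on the hunk above: the re-wrapped expression is the usual inf/NaN check, and "value != value" is true only for NaN. A small self-contained illustration, not part of the commit:

import math

value = float('nan')
# Same test as in training_log(): flags inf, -inf and NaN without math.isnan.
is_nan = value == float('inf') or \
         value == -float('inf') or \
         value != value
assert is_nan and math.isnan(value)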
@@ -539,10 +542,10 @@ def add_to_logging(name):
         get_num_microbatches()
 
     total_iterations = total_loss_dict[advanced_iters_key] + \
-        total_loss_dict[skipped_iters_key]
+                       total_loss_dict[skipped_iters_key]
 
     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0) and \
+    if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
        is_last_rank():
         if args.log_learning_rate_to_tensorboard:
             writer.add_scalar('learning-rate', learning_rate, iteration)
@@ -553,7 +556,7 @@ def add_to_logging(name):
             writer.add_scalar('batch-size vs samples', batch_size,
                               args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key , loss_dict[key], iteration)
             writer.add_scalar(key + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
         if args.log_loss_scale_to_tensorboard:
@@ -595,7 +598,7 @@ def add_to_logging(name):
         if key not in [advanced_iters_key, skipped_iters_key,
                        nan_iters_key]:
             avg = total_loss_dict[key].item() / \
-                float(max(1, total_loss_dict[advanced_iters_key]))
+                  float(max(1, total_loss_dict[advanced_iters_key]))
             if avg > 0.0:
                 log_string += ' {}: {:.6E} |'.format(key, avg)
             total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
@@ -663,10 +666,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.deepspeed:
             # inform deepspeed of any batch size changes
             global_batch_size = mpu.get_data_parallel_world_size() * \
-                args.micro_batch_size * \
-                get_num_microbatches()
+                                args.micro_batch_size * \
+                                get_num_microbatches()
             model[0].set_train_batch_size(global_batch_size)
 
+
         loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func,
                        train_data_iterator,
@@ -675,8 +679,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
-            args.micro_batch_size * \
-            get_num_microbatches()
+                                       args.micro_batch_size * \
+                                       get_num_microbatches()
 
         # Logging.
         if args.deepspeed:
@@ -739,6 +743,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
 
+
     return iteration
 
 
@@ -767,17 +772,17 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 forward_backward_func = forward_backward_pipelining_without_interleaving
             else:
                 forward_backward_func = forward_backward_no_pipelining
-
+
             if args.deepspeed:
                 # DeepSpeed uses eval_batch() and already aggregates losses.
                 assert isinstance(model, list) and len(model) == 1
                 loss = model[0].eval_batch(data_iterator)
-                loss_dicts = [{'lm loss': loss}] * get_num_microbatches()
+                loss_dicts = [{'lm loss' : loss}] * get_num_microbatches()
             else:
                 loss_dicts = forward_backward_func(
                     forward_step_func, data_iterator, model, optimizer=None,
                     timers=None, forward_only=True)
-
+
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 # Reduce across processes.
                 for loss_dict in loss_dicts:
@@ -786,8 +791,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                             key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
 
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
-                * args.micro_batch_size \
-                * get_num_microbatches()
+                                           * args.micro_batch_size \
+                                           * get_num_microbatches()
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()
@@ -797,7 +802,6 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
 
     return total_loss_dict
 
-
 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
                                iteration, verbose=False):
@@ -835,7 +839,6 @@ def cyclic_iter(iter):
         for x in iter:
             yield x
 
-
 def build_train_valid_test_data_iterators(
         build_train_valid_test_datasets_provider):
     """XXX"""
@@ -865,7 +868,7 @@ def build_train_valid_test_data_iterators(
     else:
         train_samples = args.train_iters * args.global_batch_size
     eval_iters = (args.train_iters // args.eval_interval + 1) * \
-        args.eval_iters
+                 args.eval_iters
     test_iters = args.eval_iters
     train_val_test_num_samples = [train_samples,
                                   eval_iters * args.global_batch_size,
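
Note: by its text content, the hunk above only re-wraps the eval_iters continuation line. The schedule it computes, sketched with made-up numbers (the real ones come from the command-line arguments; args.eval_iters is renamed eval_iters_per_run here to avoid shadowing the total):

train_iters = 1000
eval_interval = 100
eval_iters_per_run = 10

# One evaluation run every eval_interval training iterations, plus a final run.
total_eval_iters = (train_iters // eval_interval + 1) * eval_iters_per_run
assert total_eval_iters == 110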
@@ -904,25 +907,26 @@
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()
 
+
     # Build iterators.
     dl_type = args.dataloader_type
     assert dl_type in ['single', 'cyclic']
 
     if train_dataloader is not None:
         train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
-            else iter(cyclic_iter(train_dataloader))
+                              else iter(cyclic_iter(train_dataloader))
     else:
         train_data_iterator = None
 
     if valid_dataloader is not None:
         valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
-            else iter(cyclic_iter(valid_dataloader))
+                              else iter(cyclic_iter(valid_dataloader))
     else:
         valid_data_iterator = None
 
     if test_dataloader is not None:
         test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
-            else iter(cyclic_iter(test_dataloader))
+                              else iter(cyclic_iter(test_dataloader))
     else:
         test_data_iterator = None
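
Note: the cyclic_iter() helper shown a few hunks up and the dataloader-wrapping hunk directly above use the same pattern. A runnable sketch with a toy iterable, not part of the commit (the real code wraps torch dataloaders and selects 'single' vs 'cyclic' via args.dataloader_type):

import itertools

def cyclic_iter(it):
    # Same idea as cyclic_iter() in the diff: repeat a finite iterable forever
    # so training never exhausts the dataloader.
    while True:
        for x in it:
            yield x

toy_loader = [0, 1, 2]
data_iterator = iter(cyclic_iter(toy_loader))
print(list(itertools.islice(data_iterator, 7)))   # [0, 1, 2, 0, 1, 2, 0]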
