16 changes: 14 additions & 2 deletions megatron/model/gpt_model.py
@@ -207,9 +207,17 @@ def _to_float16(inputs):
tied_weight_attr='word_embeddings_weight'))

if args.fp32_residual_connection:
-    self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+    if hasattr(args, 'attn_mask'):
+        self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+    else:
+        # EmbeddingPipe returns attention mask as well
+        self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:]))
else:
-    self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+    if hasattr(args, 'attn_mask'):
+        self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+    else:
+        # EmbeddingPipe returns attention mask as well
+        self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]))

for layer_idx in range(args.num_layers):
self.specs.append(
@@ -222,6 +230,10 @@ def _to_float16(inputs):
self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal))


+if not hasattr(args, 'attn_mask'):
+    # We drop attention mask from the pipeline
+    self.specs.append(lambda x: x[0])
Member Author

Unfortunately we drop the attention mask from the pipeline here ...
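To make that concrete, here is a minimal sketch with dummy tensors (not the real LayerSpec pipeline, and with the transformer layers omitted) of how the `(hidden_states, attention_mask)` tuple flows through the specs when `args` has no precomputed `attn_mask`, and where the mask is discarded:

```python
import torch

# simplified stand-ins for the lambdas appended to self.specs
specs = [
    # data format change; the mask is still carried along
    lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]),
    # the spec added in this hunk: keep only the hidden states
    lambda x: x[0],
    # undo data format change; by now x is a plain tensor again
    lambda x: x.transpose(0, 1).contiguous(),
]

hidden_states = torch.randn(2, 4, 8)                       # [batch, seq, hidden]
attention_mask = torch.ones(2, 1, 4, 4, dtype=torch.bool)  # dummy mask

out = (hidden_states, attention_mask)
for spec in specs:
    out = spec(out)

print(out.shape)  # torch.Size([2, 4, 8]) -- the mask is gone
```

Anything appended to `self.specs` after the `lambda x: x[0]` step only ever sees the hidden states.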


# Undo data format change
self.specs.append(lambda x: x.transpose(0, 1).contiguous())

1 change: 0 additions & 1 deletion megatron/model/language_model.py
@@ -290,7 +290,6 @@ def forward(self, inputs, **kwargs):
if hasattr(self._args, 'attn_mask'):
    return embeddings
else:
-    assert False
Member Author

We remove this assert in order to allow this case: the embedding stage also returning the attention mask.
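For reference, a simplified, self-contained sketch of the branch this change allows (not the full EmbeddingPipe.forward; the attribute names follow the diff, and the prefix-lm motivation is an assumption based on the rest of this PR):

```python
class EmbeddingPipeSketch:
    """Illustrative stand-in for EmbeddingPipe's return behaviour."""

    def __init__(self, args, embed_fn):
        self._args = args          # parsed megatron args
        self._embed_fn = embed_fn  # callable: input_ids -> embeddings

    def forward(self, input_ids, attention_mask):
        embeddings = self._embed_fn(input_ids)
        if hasattr(self._args, 'attn_mask'):
            # the mask is precomputed and stored in args, no need to pipeline it
            return embeddings
        # previously `assert False`; now the mask travels with the activations,
        # e.g. when it cannot be precomputed because it depends on the batch
        return embeddings, attention_mask
```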

+    return embeddings, attention_mask


16 changes: 8 additions & 8 deletions pretrain_gpt.py
@@ -48,14 +48,6 @@ def model_provider(pre_process=True, post_process=True):
enabled=args.zero_stage == 3,
mpu=mpu):
if args.deepspeed:
-model = GPTModelPipe(
-    num_tokentypes=0,
-    parallel_output=True
-)
-# This is a hack to give us a reference to get_batch_pipe from within training.py
-# We need to call model.set_batch_fn after deepspeed.initialize
-model._megatron_batch_fn = get_batch_pipe

# Precompute the attention mask and store it in args. This avoids having to
# pipeline it as an activation during training. The mask is constant, and thus
# we can reuse it.
@@ -73,6 +65,14 @@
# must be bool or the training crashes expecting bool, but getting Half
args.attn_mask = attention_mask.to(torch.bool)
args.attn_mask_original = attention_mask.to(torch.bool)

+model = GPTModelPipe(
+    num_tokentypes=0,
+    parallel_output=True
+)
+# This is a hack to give us a reference to get_batch_pipe from within training.py
+# We need to call model.set_batch_fn after deepspeed.initialize
+model._megatron_batch_fn = get_batch_pipe
Comment on lines +69 to +75
Member Author

We move this part of the code to after the attention mask is set in args. This allows us to use args while building the model, typically to distinguish GPT from prefix-LM.
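A minimal, self-contained sketch of the resulting order (illustrative only: the real mask construction sits in the collapsed lines above, and a plain function stands in for GPTModelPipe):

```python
import argparse
import torch

args = argparse.Namespace(seq_length=128)

# 1. precompute a constant causal mask and stash it in args
#    (a plain lower-triangular mask is used here for illustration)
mask = torch.tril(
    torch.ones((1, args.seq_length, args.seq_length))
).view(1, 1, args.seq_length, args.seq_length)
args.attn_mask = mask.to(torch.bool)  # must be bool, as noted in the diff

# 2. only now build the model, so its constructor can inspect args,
#    e.g. hasattr(args, 'attn_mask') to tell GPT apart from prefix-LM
def build_model():
    assert hasattr(args, 'attn_mask')
    return "model"  # stand-in for GPTModelPipe(num_tokentypes=0, parallel_output=True)

model = build_model()
```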

else:
model = GPTModel(
num_tokentypes=0,
110 changes: 110 additions & 0 deletions tests/test_training.py
@@ -266,3 +266,113 @@ def test_training_all(self, variation):
# test tensorboard (1 file from the first run, plus 1 now)
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 2, "tensorboard files")

def test_training_prefix_lm_all(self):
# all in one test
src_dir = self.src_dir
data_dir = f"{self.data_dir}/gpt2"
output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False)

pp_size, tp_size, dp_size = get_3d_dimensions()
num_gpus = pp_size * tp_size * dp_size

n_samples = 200 # about 37 iterations
exit_interval = 20 # some samples in the first half and then some more in the 2nd half after resume
args = f"""
--tensor-model-parallel-size {tp_size}
--pipeline-model-parallel-size {pp_size}
--distributed-backend nccl

--num-layers 2
--hidden-size 64
--num-attention-heads 2
--seq-length 128
--max-position-embeddings 1024
--micro-batch-size 1
--rampup-batch-size 2 2 {n_samples}
--global-batch-size 16
--train-samples {n_samples}
--loss-on-targets-only

--optimizer adam
--adam-beta1 0.9
--adam-beta2 0.95
--adam-eps 1e-8
--lr 1e-4
--lr-warmup-samples 5
--clip-grad 1.0
--weight-decay 1e-1
--fp16

--log-interval 5
--save-interval 10
--eval-interval 10
--eval-iters 5
--checkpoint-activations
--glu-activation geglu
--exit-interval {exit_interval}

--merge-file {data_dir}/gpt2-tiny-merges.txt
--vocab-file {data_dir}/gpt2-tiny-vocab.json
--save {output_dir}/checkpoints
--load {output_dir}/checkpoints
--data-path {data_dir}/meg-gpt2-openwebtext_text_document
--codecarbon-dir {output_dir}/codecarbon
--tensorboard-dir {output_dir}/tensorboard
--tensorboard-queue-size 5
--log-timers-to-tensorboard
--log-batch-size-to-tensorboard
--log-validation-ppl-to-tensorboard
""".split()

ds_args = f"""
--deepspeed
--deepspeed_config {self.test_file_dir_str}/ds_config.json
--zero-stage 1
--deepspeed-activation-checkpointing
""".split()

script = [f"{src_dir}/pretrain_prefix_lm.py"]
launcher = get_launcher(num_gpus)

cmd = launcher + script + args + ds_args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die

# 1. test training from scratch (no checkpoint)
with CaptureStdout() as cs:
execute_subprocess_async(cmd, env=self.get_env())

# test deepspeed is running
self.assertIn("DeepSpeed info", cs.out)

# test reports
self.assertIn("consumed samples", cs.out)

# test there should be no checkpoint this round
self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out)

# test checkpoint saving
self.assertIn("successfully saved checkpoint at iteration", cs.out)

# test tensorboard
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 1, "tensorboard files")

# 2. test training from checkpoint: resume
# now do it again, this time resuming from the checkpoint
with CaptureStdout() as cs:
execute_subprocess_async(cmd, env=self.get_env())

# test checkpoint loading
self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out)

# test reports
self.assertIn("consumed samples", cs.out)

# test checkpoint saving
self.assertIn("successfully saved checkpoint at iteration", cs.out)

# test tensorboard (1 file from the first run, plus 1 now)
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 2, "tensorboard files")