
Commit da31db6

Fix deepspeed prefix-lm (#107)
* Fix pretrain prefix lm using deepspeed
* Fix: self._args to args
* First set attn_mask in model and then build model
* Fix: enforce that we pass down tuple instead of generator
* Attention mask does not need to be transposed
* BIGGEST HACK EVER
* Remove BIGGEST HACK
* Skip prefix test as PP>1 doesn't work yet on deepspeed
* Unskip prefix test
* Merge branch 'main' into thomas/fix_deepspeed_prefix
1 parent b5098e6 commit da31db6

File tree: 4 files changed, +132 additions, -11 deletions

  megatron/model/gpt_model.py
  megatron/model/language_model.py
  pretrain_gpt.py
  tests/test_training.py

megatron/model/gpt_model.py

Lines changed: 14 additions & 2 deletions

@@ -207,9 +207,17 @@ def _to_float16(inputs):
                               tied_weight_attr='word_embeddings_weight'))

         if args.fp32_residual_connection:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            if hasattr(args, 'attn_mask'):
+                self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            else:
+                # EmbeddingPipe returns attention mask as well
+                self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:]))
         else:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            if hasattr(args, 'attn_mask'):
+                self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            else:
+                # EmbeddingPipe returns attention mask as well
+                self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]))

         for layer_idx in range(args.num_layers):
             self.specs.append(
@@ -222,6 +230,10 @@ def _to_float16(inputs):
                     self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal))


+        if not hasattr(args, 'attn_mask'):
+            # We drop attention mask from the pipeline
+            self.specs.append(lambda x: x[0])
+
         # Undo data format change
         self.specs.append(lambda x: x.transpose(0, 1).contiguous())

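Note: the specs above rely on a simple convention: if the attention mask is precomputed on args, each pipeline stage passes a bare tensor; otherwise the mask rides along as the second element of a tuple (a tuple, not a generator, per the commit message). A minimal, self-contained sketch of that convention follows; the names Args and make_transpose_spec are illustrative, not the module's real code.

import torch

class Args:
    # hypothetical stand-in for Megatron's global args namespace
    pass

def make_transpose_spec(args):
    if hasattr(args, 'attn_mask'):
        # mask lives on args, so only the activations flow through the pipe
        return lambda x: x.transpose(0, 1).contiguous()
    # mask travels alongside the activations: transpose only the first element
    # and re-pack as a tuple (not a generator) so it can be pipelined
    return lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])

args = Args()
hidden = torch.randn(2, 4, 8)                       # [batch, seq, hidden]
mask = torch.ones(2, 1, 4, 4, dtype=torch.bool)

out = make_transpose_spec(args)((hidden, mask))      # tuple in -> tuple out
assert out[0].shape == (4, 2, 8) and out[1] is mask

args.attn_mask = mask                                # mask now precomputed
assert make_transpose_spec(args)(hidden).shape == (4, 2, 8)
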
megatron/model/language_model.py

Lines changed: 0 additions & 1 deletion

@@ -290,7 +290,6 @@ def forward(self, inputs, **kwargs):
         if hasattr(self._args, 'attn_mask'):
             return embeddings
         else:
-            assert False
             return embeddings, attention_mask

pretrain_gpt.py

Lines changed: 8 additions & 8 deletions

@@ -48,14 +48,6 @@ def model_provider(pre_process=True, post_process=True):
                              enabled=args.zero_stage == 3,
                              mpu=mpu):
         if args.deepspeed:
-            model = GPTModelPipe(
-                num_tokentypes=0,
-                parallel_output=True
-            )
-            # This is a hack to give us a reference to get_batch_pipe from within training.py
-            # We need to call model.set_batch_fn after deepspeed.initialize
-            model._megatron_batch_fn = get_batch_pipe
-
             # Precompute the attention mask and store it in args. This avoids having to
             # pipeline it as an activation during training. The mask is constant, and thus
             # we can reuse it.
@@ -73,6 +65,14 @@ def model_provider(pre_process=True, post_process=True):
             # must be bool or the training crashes expecting bool, but getting Half
             args.attn_mask = attention_mask.to(torch.bool)
             args.attn_mask_original = attention_mask.to(torch.bool)
+
+            model = GPTModelPipe(
+                num_tokentypes=0,
+                parallel_output=True
+            )
+            # This is a hack to give us a reference to get_batch_pipe from within training.py
+            # We need to call model.set_batch_fn after deepspeed.initialize
+            model._megatron_batch_fn = get_batch_pipe
         else:
             model = GPTModel(
                 num_tokentypes=0,

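The reordering above is the heart of the fix: args.attn_mask has to exist before GPTModelPipe is constructed, because the hasattr(args, 'attn_mask') checks in gpt_model.py run at build time, and the mask must be bool or later stages receive a Half tensor. A minimal, runnable sketch of that flow, using hypothetical helpers (FakeArgs, build_causal_mask) in place of the real Megatron code:

import torch

class FakeArgs:
    # hypothetical stand-in for Megatron's global args namespace
    micro_batch_size = 1
    seq_length = 8

def build_causal_mask(batch, seq):
    # hypothetical helper: lower-triangular causal mask, one per sample
    return torch.tril(torch.ones(batch, 1, seq, seq))

def model_provider_sketch(args):
    # 1. precompute the mask and attach it to args BEFORE building the model;
    #    cast to bool so downstream code never sees a Half-typed mask
    attention_mask = build_causal_mask(args.micro_batch_size, args.seq_length)
    args.attn_mask = attention_mask.to(torch.bool)
    args.attn_mask_original = attention_mask.to(torch.bool)

    # 2. only now "build" the model: a constructor-time hasattr check sees the
    #    precomputed mask and skips pipelining it as an activation
    return hasattr(args, 'attn_mask')

args = FakeArgs()
assert model_provider_sketch(args) is True
assert args.attn_mask.dtype == torch.bool
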
tests/test_training.py

Lines changed: 110 additions & 0 deletions

@@ -266,3 +266,113 @@ def test_training_all(self, variation):
         # test tensorboard (1 file from the first run, plus 1 now)
         tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
         self.assertEqual(len(tensorboard_files), 2, "tensorboard files")
+
+    def test_training_prefix_lm_all(self):
+        # all in one test
+        src_dir = self.src_dir
+        data_dir = f"{self.data_dir}/gpt2"
+        output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False)
+
+        pp_size, tp_size, dp_size = get_3d_dimensions()
+        num_gpus = pp_size * tp_size * dp_size
+
+        n_samples = 200 # about 37 iterations
+        exit_interval = 20 # some samples in the first half and then some more in the 2nd half after resume
+        args = f"""
+            --tensor-model-parallel-size {tp_size}
+            --pipeline-model-parallel-size {pp_size}
+            --distributed-backend nccl
+
+            --num-layers 2
+            --hidden-size 64
+            --num-attention-heads 2
+            --seq-length 128
+            --max-position-embeddings 1024
+            --micro-batch-size 1
+            --rampup-batch-size 2 2 {n_samples}
+            --global-batch-size 16
+            --train-samples {n_samples}
+            --loss-on-targets-only
+
+            --optimizer adam
+            --adam-beta1 0.9
+            --adam-beta2 0.95
+            --adam-eps 1e-8
+            --lr 1e-4
+            --lr-warmup-samples 5
+            --clip-grad 1.0
+            --weight-decay 1e-1
+            --fp16
+
+            --log-interval 5
+            --save-interval 10
+            --eval-interval 10
+            --eval-iters 5
+            --checkpoint-activations
+            --glu-activation geglu
+            --exit-interval {exit_interval}
+
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab-file {data_dir}/gpt2-tiny-vocab.json
+            --save {output_dir}/checkpoints
+            --load {output_dir}/checkpoints
+            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+            --codecarbon-dir {output_dir}/codecarbon
+            --tensorboard-dir {output_dir}/tensorboard
+            --tensorboard-queue-size 5
+            --log-timers-to-tensorboard
+            --log-batch-size-to-tensorboard
+            --log-validation-ppl-to-tensorboard
+        """.split()
+
+        ds_args = f"""
+            --deepspeed
+            --deepspeed_config {self.test_file_dir_str}/ds_config.json
+            --zero-stage 1
+            --deepspeed-activation-checkpointing
+        """.split()
+
+        script = [f"{src_dir}/pretrain_prefix_lm.py"]
+        launcher = get_launcher(num_gpus)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+
+        # 1. test training from scratch (no checkpoint)
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test deepspeed is running
+        self.assertIn("DeepSpeed info", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test there should be no checkpoint this round
+        self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 1, "tensorboard files")
+
+        # 2. test training from checkpoint: resume
+        # now do it again, this time resuming from the checkpoint
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test checkpoint loading
+        self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard (1 file from the first run, plus 1 now)
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 2, "tensorboard files")
