Fix correctness issue with megatron save/load checkpoints (#1386)
Summary:
# Before submitting

- [x] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Fixes facebookresearch/fairseq#2681. The Megatron trainer now stores the CUDA RNG tracker states in the checkpoint's `extra_state` when saving and restores them when loading, so model-parallel training resumes with the same randomness it had before the restart.

Proof that it's working now (the relaunched run reproduces the original per-update losses):
```
python fairseq_train.py --task masked_lm /checkpoint/bioseq_nonsecure/model-parallel-data/tiny_sample_valid_ur50-bin  --dataset-impl fasta  --save-dir checkpoints/mp-fix4    --dropout 0.1   --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0   --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07   --tokens-per-sample 128 --sample-break-mode none   --max-tokens 128 --no-progress-bar --log-interval 1 --seed 4 --max-epoch 1 --max-update 50 --encoder-layers 4  --arch model_parallel_roberta_large --model-parallel-size 2 --update-freq 2 --save-interval-updates 10

2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     11 / 78 loss=0.939, ppl=1.92, wps=116.7, ups=0.11, wpb=1024, bsz=8, num_updates=11, lr=1.47473e-06, gnorm=2.276, train_wall=0, wall=15
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     12 / 78 loss=0.938, ppl=1.92, wps=15769.2, ups=15.38, wpb=1024, bsz=8, num_updates=12, lr=1.5997e-06, gnorm=2.612, train_wall=0, wall=15
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     13 / 78 loss=0.877, ppl=1.84, wps=18658.8, ups=18.2, wpb=1024, bsz=8, num_updates=13, lr=1.72468e-06, gnorm=2.798, train_wall=0, wall=15
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     14 / 78 loss=0.887, ppl=1.85, wps=18324.5, ups=17.88, wpb=1024, bsz=8, num_updates=14, lr=1.84965e-06, gnorm=2.326, train_wall=0, wall=15
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     15 / 78 loss=0.867, ppl=1.82, wps=17616.5, ups=17.19, wpb=1024, bsz=8, num_updates=15, lr=1.97463e-06, gnorm=2.112, train_wall=0, wall=15
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     16 / 78 loss=0.891, ppl=1.85, wps=18624.5, ups=18.17, wpb=1024, bsz=8, num_updates=16, lr=2.0996e-06, gnorm=2.123, train_wall=0, wall=16
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     17 / 78 loss=0.887, ppl=1.85, wps=17972.5, ups=17.53, wpb=1024, bsz=8, num_updates=17, lr=2.22458e-06, gnorm=2.061, train_wall=0, wall=16
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     18 / 78 loss=0.862, ppl=1.82, wps=14672.4, ups=14.32, wpb=1024, bsz=8, num_updates=18, lr=2.34955e-06, gnorm=2.282, train_wall=0, wall=16
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     19 / 78 loss=0.876, ppl=1.83, wps=14398.6, ups=14.05, wpb=1024, bsz=8, num_updates=19, lr=2.47453e-06, gnorm=2.261, train_wall=0, wall=16
2020-10-29 18:42:08 | INFO | train_inner | epoch 001:     20 / 78 loss=0.818, ppl=1.76, wps=18652.2, ups=18.2, wpb=1024, bsz=8, num_updates=20, lr=2.5995e-06, gnorm=1.969, train_wall=0, wall=16

...relaunch...

2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     11 / 78 loss=0.939, ppl=1.92, wps=98.2, ups=0.1, wpb=1024, bsz=8, num_updates=11, lr=1.47473e-06, gnorm=2.276, train_wall=1, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     12 / 78 loss=0.938, ppl=1.92, wps=17137.8, ups=16.72, wpb=1024, bsz=8, num_updates=12, lr=1.5997e-06, gnorm=2.612, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     13 / 78 loss=0.877, ppl=1.84, wps=17239.6, ups=16.82, wpb=1024, bsz=8, num_updates=13, lr=1.72468e-06, gnorm=2.798, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     14 / 78 loss=0.887, ppl=1.85, wps=18132, ups=17.69, wpb=1024, bsz=8, num_updates=14, lr=1.84965e-06, gnorm=2.326, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     15 / 78 loss=0.867, ppl=1.82, wps=17795.1, ups=17.36, wpb=1024, bsz=8, num_updates=15, lr=1.97463e-06, gnorm=2.112, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     16 / 78 loss=0.891, ppl=1.85, wps=18021.3, ups=17.58, wpb=1024, bsz=8, num_updates=16, lr=2.0996e-06, gnorm=2.123, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     17 / 78 loss=0.887, ppl=1.85, wps=16452.9, ups=16.05, wpb=1024, bsz=8, num_updates=17, lr=2.22458e-06, gnorm=2.061, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     18 / 78 loss=0.862, ppl=1.82, wps=17563.3, ups=17.14, wpb=1024, bsz=8, num_updates=18, lr=2.34955e-06, gnorm=2.282, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     19 / 78 loss=0.876, ppl=1.83, wps=16770.3, ups=16.36, wpb=1024, bsz=8, num_updates=19, lr=2.47453e-06, gnorm=2.261, train_wall=0, wall=0
2020-10-29 18:47:20 | INFO | train_inner | epoch 001:     20 / 78 loss=0.818, ppl=1.76, wps=16808.2, ups=16.4, wpb=1024, bsz=8, num_updates=20, lr=2.5995e-06, gnorm=1.969, train_wall=0, wall=0
```
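
As a sanity check beyond eyeballing the logs, a small helper script (hypothetical, not part of this PR; the log file names are placeholders) can compare the per-update losses of the original and relaunched runs:

```
import re

# Matches "loss=0.939, ... num_updates=11" in fairseq train_inner log lines.
PATTERN = re.compile(r"loss=([\d.]+).*?num_updates=(\d+)")

def losses_by_update(path):
    """Return {num_updates: loss} for every train_inner line in a log file."""
    losses = {}
    with open(path) as log:
        for line in log:
            match = PATTERN.search(line)
            if "train_inner" in line and match:
                losses[int(match.group(2))] = float(match.group(1))
    return losses

# "run1.log" and "run2_relaunch.log" are placeholder paths.
first = losses_by_update("run1.log")
second = losses_by_update("run2_relaunch.log")
overlap = sorted(set(first) & set(second))
mismatches = [u for u in overlap if first[u] != second[u]]
print(f"{len(overlap)} overlapping updates, {len(mismatches)} loss mismatches")
```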

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: fairinternal/fairseq-py#1386

Reviewed By: myleott

Differential Revision: D24640946

Pulled By: joshim5

fbshipit-source-id: cb141d92496b289a04d53f080ecd4d5ac6941672
joshim5 authored and facebook-github-bot committed Nov 3, 2020
1 parent de97773 commit b120fbb
Showing 1 changed file with 21 additions and 1 deletion.
fairseq/model_parallel/megatron_trainer.py
@@ -11,14 +11,14 @@
from fairseq.trainer import Trainer
from fairseq.dataclass.configs import FairseqConfig


try:
from fairseq.model_parallel.megatron.mpu import (
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_model_parallel_group,
get_model_parallel_src_rank,
get_cuda_rng_tracker,
)

has_megatron_submodule = True
@@ -65,3 +65,23 @@ def _aggregate_model_parallel_grad_norm(total_norm):
clip_norm,
aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
)

def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
extra_state['rng_tracker_states'] \
= get_cuda_rng_tracker().get_states()
super().save_checkpoint(filename, extra_state)

def load_checkpoint(
self,
filename,
reset_optimizer=False,
reset_lr_scheduler=False,
optimizer_overrides=None,
reset_meters=False,
):
extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters)
if extra_state is not None and 'rng_tracker_states' in extra_state:
get_cuda_rng_tracker().set_states(
extra_state['rng_tracker_states'])
return extra_state
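
For readers outside the fairseq codebase, here is a minimal, standalone sketch of the pattern the two new methods implement: stash the RNG state in the checkpoint's extra state on save and put it back on load. It uses torch's CPU RNG state as a stand-in for Megatron's `get_cuda_rng_tracker()` (an assumption made so the sketch runs without the megatron submodule or a GPU); the file name and dictionary key are illustrative only.

```
import torch

def save_checkpoint(filename, extra_state):
    # Mirror MegatronTrainer.save_checkpoint: add the RNG state to extra_state
    # before writing the checkpoint, so dropout randomness survives a restart.
    extra_state["rng_state"] = torch.get_rng_state()
    torch.save({"extra_state": extra_state}, filename)

def load_checkpoint(filename):
    # Mirror MegatronTrainer.load_checkpoint: put the RNG state back if the
    # checkpoint carried one, then hand extra_state back to the caller.
    extra_state = torch.load(filename)["extra_state"]
    if extra_state is not None and "rng_state" in extra_state:
        torch.set_rng_state(extra_state["rng_state"])
    return extra_state

if __name__ == "__main__":
    save_checkpoint("ckpt.pt", {"num_updates": 11})
    draw_before = torch.rand(4)   # e.g. the next dropout mask
    load_checkpoint("ckpt.pt")
    draw_after = torch.rand(4)    # same draw after the "relaunch"
    assert torch.equal(draw_before, draw_after)
    print("RNG state round-tripped through the checkpoint")
```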
