From 5e24982e580ab32a779d7271ad8cb46dc5c6475f Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Wed, 28 Oct 2020 18:59:14 +0000
Subject: [PATCH] Upgrade PyTorch Lightning to 1.0.2 (#7852)

Co-authored-by: Sam Shleifer
---
 examples/lightning_base.py                          | 5 +++--
 examples/requirements.txt                           | 2 +-
 examples/seq2seq/callbacks.py                       | 1 -
 examples/seq2seq/finetune.py                        | 5 ++---
 examples/seq2seq/test_bash_script.py                | 2 +-
 examples/seq2seq/test_seq2seq_examples_multi_gpu.py | 1 -
 examples/text-classification/run_pl_glue.py         | 2 +-
 examples/token-classification/run_pl_ner.py         | 6 +++---
 8 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/examples/lightning_base.py b/examples/lightning_base.py
index 6ff4a08fc4ac9d..739e5dc59650dd 100644
--- a/examples/lightning_base.py
+++ b/examples/lightning_base.py
@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
 def generic_train(
     model: BaseTransformer,
     args: argparse.Namespace,
-    early_stopping_callback=False,
+    early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
     extra_callbacks=[],
     checkpoint_callback=None,
@@ -355,6 +355,8 @@ def generic_train(
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
         )
+    if early_stopping_callback:
+        extra_callbacks.append(early_stopping_callback)
     if logging_callback is None:
         logging_callback = LoggingCallback()

@@ -376,7 +378,6 @@ def generic_train(
         callbacks=[logging_callback] + extra_callbacks,
         logger=logger,
         checkpoint_callback=checkpoint_callback,
-        early_stop_callback=early_stopping_callback,
         **train_params,
     )
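
Note on the lightning_base.py change: the pytorch-lightning 1.0 line no longer accepts the Trainer's early_stop_callback argument, which is why generic_train now appends any early-stopping callback to the generic callbacks list instead. A minimal sketch of the 1.0-style wiring, assuming a LightningModule that logs "val_loss"; the monitor name, patience, and epoch count here are illustrative, not taken from the patch:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping

    # Early stopping is now just another entry in `callbacks`;
    # Trainer(early_stop_callback=...) is gone in the 1.0 releases.
    early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=3)
    trainer = pl.Trainer(max_epochs=3, callbacks=[early_stopping])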
diff --git a/examples/requirements.txt b/examples/requirements.txt
index 120a3ab5e06cfa..9c270479678916 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.9.0
+pytorch-lightning==1.0.4
 matplotlib
 git-python==1.0.3
 faiss-cpu
diff --git a/examples/seq2seq/callbacks.py b/examples/seq2seq/callbacks.py
index c6cd2014ded49b..64560487496dcf 100644
--- a/examples/seq2seq/callbacks.py
+++ b/examples/seq2seq/callbacks.py
@@ -102,7 +102,6 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
         monitor=f"val_{metric}",
         mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback

diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index 9da761db73b42e..7e57f7ba40f1c4 100755
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -182,7 +182,6 @@ def validation_step(self, batch, batch_idx) -> Dict:
         return self._generative_step(batch)

     def validation_epoch_end(self, outputs, prefix="val") -> Dict:
-        self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
@@ -252,7 +251,7 @@ def get_dataset(self, type_path) -> Seq2SeqDataset:

     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)

-        if self.hparams.sortish_sampler and type_path != "test":
+        if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
             sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             return DataLoader(
                 dataset,
@@ -263,7 +262,7 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)
                 sampler=sampler,
             )

-        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
+        elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
             batch_sampler = dataset.make_dynamic_sampler(
                 self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
             )
diff --git a/examples/seq2seq/test_bash_script.py b/examples/seq2seq/test_bash_script.py
index 24ce9bfe6b49c5..71861ef4dbc6a3 100644
--- a/examples/seq2seq/test_bash_script.py
+++ b/examples/seq2seq/test_bash_script.py
@@ -144,6 +144,7 @@ def test_opus_mt_distill_script(self):
                 f"--num_train_epochs={epochs}",
                 "--warmup_steps=10",
                 "--val_check_interval=1.0",
+                "--do_predict",
             ]
         )
         with patch.object(sys, "argv", testargs):
@@ -151,7 +152,6 @@
             parser = pl.Trainer.add_argparse_args(parser)
             parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
             args = parser.parse_args()
-            args.do_predict = False
             # assert args.gpus == gpus THIS BREAKS for multigpu
             model = distill_main(args)
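
Note on the test_bash_script.py change: the distillation test now passes --do_predict on the simulated command line instead of flipping args.do_predict after parsing, and the arguments still flow through Lightning's argparse helpers. A rough, self-contained sketch of that argparse round-trip; the flag values below are placeholders, not the ones the test uses:

    import argparse

    import pytorch_lightning as pl

    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)  # adds --max_epochs, --gpus, ...
    args = parser.parse_args(["--max_epochs=1", "--gpus=0"])
    trainer = pl.Trainer.from_argparse_args(args)  # Trainer built from the parsed flags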
diff --git a/examples/seq2seq/test_seq2seq_examples_multi_gpu.py b/examples/seq2seq/test_seq2seq_examples_multi_gpu.py
index a6b76a4c530a6f..03ec39037c15b4 100644
--- a/examples/seq2seq/test_seq2seq_examples_multi_gpu.py
+++ b/examples/seq2seq/test_seq2seq_examples_multi_gpu.py
@@ -176,7 +176,6 @@ def convert(k, v):
         print(metrics)
         last_step_stats = metrics["val"][-1]
         self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
-        self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
         self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
         self.assertEqual(len(metrics["test"]), 1)
         desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
diff --git a/examples/text-classification/run_pl_glue.py b/examples/text-classification/run_pl_glue.py
index 80315abc56bbb9..500a0bd627643d 100644
--- a/examples/text-classification/run_pl_glue.py
+++ b/examples/text-classification/run_pl_glue.py
@@ -192,7 +192,7 @@ def main():

     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
diff --git a/examples/token-classification/run_pl_ner.py b/examples/token-classification/run_pl_ner.py
index c82cff74d8ef4c..1066c6fed48cc9 100644
--- a/examples/token-classification/run_pl_ner.py
+++ b/examples/token-classification/run_pl_ner.py
@@ -207,9 +207,9 @@ def add_model_specific_args(parser, root_dir):
     if args.do_predict:
         # See https://github.com/huggingface/transformers/issues/3159
-        # pl use this format to create a checkpoint:
+        # pl use this default format to create a checkpoint:
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
-        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L322
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
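
Note on the glob changes in run_pl_glue.py and run_pl_ner.py: with the ModelCheckpoint settings used in lightning_base.py (an output directory plus prefix="checkpoint"), pytorch-lightning 1.0.x joins the prefix and the default "{epoch}" filename with a dash, so checkpoints land as e.g. checkpoint-epoch=2.ckpt rather than 0.9's checkpointepoch=2.ckpt. A small sketch of locating and loading such a file, assuming an output_dir variable and a hypothetical LightningModule subclass MyModule:

    import glob
    import os

    # Checkpoints are written as "<output_dir>/checkpoint-epoch=N.ckpt" under PL >= 1.0.
    checkpoints = sorted(glob.glob(os.path.join(output_dir, "checkpoint-epoch=*.ckpt")))
    model = MyModule.load_from_checkpoint(checkpoints[-1])  # MyModule is a placeholder LightningModule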