
Commit 5e24982

Sean Naren and sshleifer authored
Upgrade PyTorch Lightning to 1.0.2 (#7852)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
1 parent 1b6c8d4 commit 5e24982

File tree

8 files changed, +11 -13 lines changed

examples/lightning_base.py

Lines changed: 3 additions & 2 deletions
@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
 def generic_train(
     model: BaseTransformer,
     args: argparse.Namespace,
-    early_stopping_callback=False,
+    early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
     extra_callbacks=[],
     checkpoint_callback=None,
@@ -355,6 +355,8 @@ def generic_train(
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
         )
+    if early_stopping_callback:
+        extra_callbacks.append(early_stopping_callback)
     if logging_callback is None:
         logging_callback = LoggingCallback()
@@ -376,7 +378,6 @@ def generic_train(
         callbacks=[logging_callback] + extra_callbacks,
         logger=logger,
         checkpoint_callback=checkpoint_callback,
-        early_stop_callback=early_stopping_callback,
         **train_params,
     )
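The Trainer in Lightning 1.0 drops its dedicated early-stopping argument, which is presumably why generic_train now appends the callback to extra_callbacks instead of forwarding early_stop_callback. A minimal sketch of that wiring, assuming pytorch-lightning >= 1.0; the monitored metric, patience, and max_epochs below are illustrative rather than taken from the commit:

# Minimal sketch, assuming pytorch-lightning >= 1.0: early stopping is passed through
# the callbacks list rather than a dedicated Trainer argument. Metric name, patience,
# and max_epochs are illustrative.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=3)
checkpointing = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[early_stopping],          # replaces the removed early_stop_callback=... argument
    checkpoint_callback=checkpointing,   # still its own argument in 1.0.x, as in the diff above
)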

examples/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.9.0
+pytorch-lightning==1.0.4
 matplotlib
 git-python==1.0.3
 faiss-cpu

examples/seq2seq/callbacks.py

Lines changed: 0 additions & 1 deletion
@@ -102,7 +102,6 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=Fa
         monitor=f"val_{metric}",
         mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
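Dropping period=0 leaves ModelCheckpoint at its default cadence of saving once per validated epoch. A sketch of the resulting callback construction, with the output directory and metric name as placeholders (filepath= still works in 1.0.x, though later releases move to dirpath=/filename=):

# Sketch of the checkpoint callback after the change; period is simply left at its
# default. The metric name and output directory are placeholders.
from pytorch_lightning.callbacks import ModelCheckpoint

metric = "rouge2"  # illustrative
checkpoint_callback = ModelCheckpoint(
    filepath="output_dir",
    monitor=f"val_{metric}",
    mode="min" if "loss" in metric else "max",
    save_top_k=1,
)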

examples/seq2seq/finetune.py

Lines changed: 2 additions & 3 deletions
@@ -182,7 +182,6 @@ def validation_step(self, batch, batch_idx) -> Dict:
         return self._generative_step(batch)

     def validation_epoch_end(self, outputs, prefix="val") -> Dict:
-
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
@@ -252,7 +251,7 @@ def get_dataset(self, type_path) -> Seq2SeqDataset:
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)

-        if self.hparams.sortish_sampler and type_path != "test":
+        if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
             sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             return DataLoader(
                 dataset,
@@ -263,7 +262,7 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)
                 sampler=sampler,
             )

-        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
+        elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
             batch_sampler = dataset.make_dynamic_sampler(
                 self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
             )
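The added type_path != "val" checks restrict the sortish and dynamic-batch samplers to training data, so validation, like test, falls back to a plain DataLoader. A condensed sketch of the resulting selection logic; build_dataloader is a hypothetical free-function stand-in for the example's get_dataloader method, and make_sortish_sampler / make_dynamic_sampler are the dataset helpers shown above:

# Condensed sketch of the dataloader selection after the change. Only the two
# dataset helper methods come from the example; the rest is illustrative.
from torch.utils.data import DataLoader


def build_dataloader(dataset, hparams, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
    train_only_sampling = type_path not in ("test", "val")

    if hparams.sortish_sampler and train_only_sampling:
        sampler = dataset.make_sortish_sampler(batch_size, distributed=hparams.gpus > 1)
        return DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, sampler=sampler)

    if hparams.max_tokens_per_batch is not None and train_only_sampling:
        batch_sampler = dataset.make_dynamic_sampler(hparams.max_tokens_per_batch, distributed=hparams.gpus > 1)
        return DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=dataset.collate_fn)

    return DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)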

examples/seq2seq/test_bash_script.py

Lines changed: 1 addition & 1 deletion
@@ -144,14 +144,14 @@ def test_opus_mt_distill_script(self):
                 f"--num_train_epochs={epochs}",
                 "--warmup_steps=10",
                 "--val_check_interval=1.0",
+                "--do_predict",
             ]
         )
         with patch.object(sys, "argv", testargs):
             parser = argparse.ArgumentParser()
             parser = pl.Trainer.add_argparse_args(parser)
             parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
             args = parser.parse_args()
-            args.do_predict = False
             # assert args.gpus == gpus THIS BREAKS for multigpu

             model = distill_main(args)
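Passing "--do_predict" on the patched command line, instead of flipping args.do_predict after parsing, keeps the test exercising the same argparse path a user would hit. A minimal, self-contained sketch of that pattern (the parser and argv contents are illustrative):

# Minimal sketch: a store_true flag picked up from a patched sys.argv rather than
# overridden on the parsed namespace. Parser and argv contents are illustrative.
import argparse
import sys
from unittest.mock import patch

parser = argparse.ArgumentParser()
parser.add_argument("--do_predict", action="store_true")

with patch.object(sys, "argv", ["finetune.py", "--do_predict"]):
    args = parser.parse_args()  # reads the patched sys.argv

assert args.do_predict is True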

examples/seq2seq/test_seq2seq_examples_multi_gpu.py

Lines changed: 0 additions & 1 deletion
@@ -176,7 +176,6 @@ def convert(k, v):
         print(metrics)
         last_step_stats = metrics["val"][-1]
         self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
-        self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
         self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
         self.assertEqual(len(metrics["test"]), 1)
         desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)

examples/text-classification/run_pl_glue.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def main():

     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
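The glob pattern changes because checkpoint files are now expected to be named "checkpoint-epoch=N.ckpt" rather than "checkpointepoch=N.ckpt", presumably reflecting how Lightning 1.0 separates the ModelCheckpoint prefix from the filename. A sketch of the discovery pattern, with an illustrative output directory:

# Sketch of locating the newest checkpoint under the assumed "checkpoint-epoch=N.ckpt"
# naming; output_dir is illustrative. Sorting is lexicographic, as in the example above.
import glob
import os

output_dir = "./glue_output"  # illustrative
checkpoints = sorted(glob.glob(os.path.join(output_dir, "checkpoint-epoch=*.ckpt")))
if checkpoints:
    latest_checkpoint = checkpoints[-1]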

examples/token-classification/run_pl_ner.py

Lines changed: 3 additions & 3 deletions
@@ -207,9 +207,9 @@ def add_model_specific_args(parser, root_dir):

     if args.do_predict:
         # See https://github.com/huggingface/transformers/issues/3159
-        # pl use this format to create a checkpoint:
+        # pl use this default format to create a checkpoint:
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
-        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L322
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
