[examples] SummarizationModule improvements #4951

Merged: 255 commits from distilbart-clean into huggingface:master on Jun 17, 2020

Conversation

@sshleifer (Contributor) commented Jun 12, 2020

This PR makes the SummarizationTrainer much more usable; improvements that are not unique to summarization are implemented in lightning_base.py instead.

  • Checkpointing: before this PR, the code saved 5GB of PL checkpoints per epoch. Now SummarizationTrainer saves only the best checkpoint, selected by ROUGE-2 score, and also writes it in Hugging Face save_pretrained format via the on_save_checkpoint hook (see the sketch below). This should resolve much of the confusion in various issues about how to load the PL checkpoints.
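
A rough sketch of the checkpointing idea, assuming pytorch-lightning's ModelCheckpoint callback and the on_save_checkpoint hook; the attribute names (self.model, self.tokenizer, self.output_dir) and the "rouge2" metric key are illustrative, not the exact code in this PR:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class SummarizationModule(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # Whenever Lightning writes a .ckpt, also save the weights and
        # tokenizer in Hugging Face format, so they can later be reloaded
        # with .from_pretrained() without any Lightning code.
        save_path = self.output_dir / "best_tfmr"
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)


# Keep only the single best Lightning checkpoint, ranked by validation ROUGE-2.
checkpoint_callback = ModelCheckpoint(monitor="rouge2", mode="max", save_top_k=1)
```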

The current summarization code only accepts bs=1 and takes 24h to run one epoch on CNN/DM. With the following changes, you can train much faster if you wish. The docs previously suggested that larger batch sizes were possible with the default params; that is now fixed.

Changes to Allow Faster Summarization Training

These are all optional and turned off by default.

  1. Freezing: before this PR, it was basically only possible to finetune with batch size 2-4 on a 16GB system. With --freeze_embeds and --freeze_encoder, you can push the batch size much higher, towards 32 (see the sketch after this list). I've seen strong results with these options.

  2. On CNN/DM and XSUM the datasets are ~200K examples, so epochs are very long. It is therefore preferable to run validation (and get a ROUGE score) more frequently, but with the previous params each validation pass took 1hr. By passing --n_val=1000 --val_check_interval=0.25, you can run validation 4x per epoch and it only takes 3 minutes. This also allows the config's beam search parameters to be used, rather than hardcoding faster but lower-scoring ones.

  3. {train|val|test}_max_target_length: I have found it preferable to truncate training summaries (e.g. to 56 tokens for XSUM), but doing this for val/test artificially inflates ROUGE scores, so these clargs are separated.
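
For reference, here is a minimal sketch of what the freezing flags in item 1 amount to; freeze_params and apply_freezing are hypothetical helpers, and the model.model.shared / model.model.encoder attribute paths assume a BART-style model:

```python
import torch.nn as nn


def freeze_params(module: nn.Module) -> None:
    """Turn off gradient updates for every parameter in `module`."""
    for p in module.parameters():
        p.requires_grad = False


def apply_freezing(model, freeze_embeds: bool, freeze_encoder: bool) -> None:
    # The shared embedding matrix and the encoder hold a large share of the
    # trainable parameters, so freezing them frees optimizer state and
    # gradient memory, allowing a much larger batch size.
    if freeze_embeds:
        freeze_params(model.model.shared)
    if freeze_encoder:
        freeze_params(model.model.encoder)
```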

Changes to lightning_base

  • The numbers of trainable and total parameters are logged by default (rough sketch below).
  • All possible pl.Trainer clargs are passed through add_generic_args (inspired by @nateraw).
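
A minimal sketch of both points, assuming standard PyTorch and pytorch-lightning APIs; count_parameters and add_generic_args_sketch are illustrative names, not necessarily the ones used in the PR:

```python
import argparse

import pytorch_lightning as pl


def count_parameters(model) -> dict:
    """Return total and trainable parameter counts for a torch.nn.Module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {"total_params": total, "trainable_params": trainable}


def add_generic_args_sketch(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # One way to expose every pl.Trainer argument on the command line is to
    # let Lightning register its own flags on the parser.
    return pl.Trainer.add_argparse_args(parser)
```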

WandbLogger

  • --logger wandb will instantiate a default wandb logger.
  • --logger wandb_shared will post results to a shared project, so that the community can compare hyperparameter settings empirically.
  • The default logger is still the TensorBoard logger, because it doesn't require making an account. A sketch of the logger selection follows this list.
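
A minimal sketch of how the --logger flag could map to Lightning loggers; get_logger and the project names are illustrative, not the exact code in this PR:

```python
from pytorch_lightning.loggers import WandbLogger


def get_logger(logger_name: str):
    if logger_name == "wandb":
        # Requires a wandb account and `wandb login`; project name is illustrative.
        return WandbLogger(project="summarization")
    if logger_name == "wandb_shared":
        # Same idea, but pointed at a shared project so the community can
        # compare hyperparameter settings side by side.
        return WandbLogger(project="hf_summarization")
    # Passing logger=True to pl.Trainer falls back to the default
    # TensorBoardLogger, which needs no account.
    return True
```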

Distillation

  • SummarizationDistiller and T5SummarizationDistiller are checked in. This code was sent to me by a researcher who wishes to remain anonymous. DM to discuss.

@sshleifer requested a review from julien-c on June 14, 2020 23:09
@sshleifer linked an issue on Jun 15, 2020 that may be closed by this pull request

@LysandreJik (Member) left a comment

This all looks very cool, looking forward to using it!

Review threads (resolved):
  • examples/lightning_base.py
  • examples/summarization/README.md (outdated)
@sshleifer (Contributor, Author) commented

Merging now. Happy to address post-merge comments!

@sshleifer changed the title from [examples] SummarizationTrainer improvements to [examples] SummarizationModule improvements on Jun 17, 2020
@sshleifer merged commit 043f9f5 into huggingface:master on Jun 17, 2020
@sshleifer deleted the distilbart-clean branch on June 17, 2020 17:51
Labels: Examples (related to examples in general), seq2seq
Development

Successfully merging this pull request may close these issues:
  • Use finetuned-BART large to do conditional generation