[examples] SummarizationModule improvements #4951

Merged: 255 commits, Jun 17, 2020

Commits
71850ad
copy decoder layers
sshleifer May 18, 2020
b01f1d5
Failing test
sshleifer May 18, 2020
937d3d6
Merge branch 'master' into distilbart
sshleifer May 18, 2020
fcc49a0
can import
sshleifer May 18, 2020
4f5790f
passing
sshleifer May 18, 2020
4be1287
real data test passes
sshleifer May 19, 2020
8afb88c
relatif importschlossen
sshleifer May 19, 2020
d2cc12b
boom boom
sshleifer May 19, 2020
1edc50f
bash
sshleifer May 19, 2020
174aaf3
Fast dev run
sshleifer May 19, 2020
5a35811
boom boom
sshleifer May 19, 2020
6302fb0
bs=8
sshleifer May 19, 2020
5a3ed99
Cache tokenized
sshleifer May 19, 2020
70cf536
Merge branch 'distilbart' of github.com:sshleifer/transformers_fork i…
sshleifer May 19, 2020
ca9c685
rouge
sshleifer May 19, 2020
7081d0f
add rouge
sshleifer May 19, 2020
0ed4156
boom boom
sshleifer May 19, 2020
04a8ace
boom boom
sshleifer May 19, 2020
696c8c2
boom boom
sshleifer May 19, 2020
3cebd56
boom boom
sshleifer May 19, 2020
bf5782e
boom boom
sshleifer May 19, 2020
74704ce
boom boom
sshleifer May 19, 2020
bbc4e52
assert student small
sshleifer May 19, 2020
c4530f9
batch
sshleifer May 19, 2020
2ee9388
boom boom
sshleifer May 19, 2020
51221fb
boom boom
sshleifer May 19, 2020
abb81df
boom boom
sshleifer May 19, 2020
bef77fc
boom boom
sshleifer May 19, 2020
2b7132c
metrics saving, but no val_check_interval honored
sshleifer May 20, 2020
2c72948
val check test passing with smaller batch size
sshleifer May 20, 2020
af315c1
boom boom
sshleifer May 20, 2020
5fbf422
summaries for file
sshleifer May 20, 2020
09cc1e6
save preds
sshleifer May 20, 2020
5947812
add files
sshleifer May 20, 2020
59c658d
fixes
sshleifer May 20, 2020
beda65d
bash
sshleifer May 20, 2020
c774160
imports
sshleifer May 20, 2020
c79351a
fix flatten
sshleifer May 20, 2020
05bfb1e
Support fewer layers
sshleifer May 20, 2020
cd461e1
fix
sshleifer May 20, 2020
67f9553
Run encoder once
sshleifer May 20, 2020
f281724
wandb
sshleifer May 20, 2020
1fc653f
log_metrics
sshleifer May 20, 2020
b62bfe6
spelling
sshleifer May 20, 2020
a250d1c
boom boom
sshleifer May 21, 2020
18524d8
boom boom
sshleifer May 21, 2020
a2671a6
boom boom
sshleifer May 21, 2020
fee71b4
npars
sshleifer May 21, 2020
a642076
boom boom
sshleifer May 21, 2020
cf47a6f
boom boom
sshleifer May 21, 2020
ce0de90
boom boom
sshleifer May 21, 2020
b39c39f
boom boom
sshleifer May 21, 2020
7a7dcd3
l2copy[1] = 0
sshleifer May 21, 2020
910da0f
boom boom
sshleifer May 21, 2020
2c2a10e
orig tests pass
sshleifer May 21, 2020
21a2cc0
delete run_distiller -> main
sshleifer May 21, 2020
07da062
avoid losing bart grouped batch sampler work
sshleifer May 21, 2020
dd418e8
boom boom
sshleifer May 21, 2020
ac72c7a
convenience method
sshleifer May 21, 2020
88b4b97
switch copy logic
sshleifer May 21, 2020
3f73cb7
boom boom
sshleifer May 21, 2020
15a2d4e
boom boom
sshleifer May 21, 2020
3461e24
boom boom
sshleifer May 21, 2020
ca82946
boom boom
sshleifer May 21, 2020
f5606df
boom boom
sshleifer May 21, 2020
5d75172
Freeze encoder fix
sshleifer May 21, 2020
aa995f9
freeze after init
sshleifer May 21, 2020
bcb9996
overcautious freezing
sshleifer May 21, 2020
29499c9
boom boom
sshleifer May 21, 2020
705ed2f
Better desc
sshleifer May 21, 2020
db3d85a
boom boom
sshleifer May 21, 2020
c6f7e14
progbar
sshleifer May 21, 2020
fa0eda4
boom boom
sshleifer May 21, 2020
021d2b9
only train
sshleifer May 21, 2020
c779c38
boom boom
sshleifer May 21, 2020
d4830f9
rename
sshleifer May 22, 2020
a78603c
boom boom
sshleifer May 22, 2020
1ca011f
test_mtl=140 default
sshleifer May 22, 2020
8fbdd7d
boom boom
sshleifer May 22, 2020
ef45ee9
boom boom
sshleifer May 22, 2020
d328871
add RougeTracker
sshleifer May 22, 2020
584c329
encoder different
sshleifer May 24, 2020
20abeec
passing
sshleifer May 24, 2020
e48a829
boom boom
sshleifer May 25, 2020
9f21462
boom boom
sshleifer May 25, 2020
6f03e39
boom boom
sshleifer May 25, 2020
2443a7e
boom boom
sshleifer May 25, 2020
4d3617a
boom boom
sshleifer May 25, 2020
6f1757d
boom boom
sshleifer May 25, 2020
af8b962
boom boom
sshleifer May 25, 2020
c7f6e62
boom boom
sshleifer May 25, 2020
c154fb3
style
sshleifer May 25, 2020
a42f09e
warmup_steps=500
sshleifer May 25, 2020
88f9f91
Sortish Sampler
sshleifer May 25, 2020
800dcdc
boom boom
sshleifer May 25, 2020
db52f37
save fewer checkpoints
sshleifer May 25, 2020
f5e697b
Dont require model_name_or_path
sshleifer May 25, 2020
20a5968
passing with enc mse
sshleifer May 26, 2020
12f2d9e
early quitting for encoder only loss
sshleifer May 26, 2020
0933e11
boom boom
sshleifer May 26, 2020
56aedac
boom boom
sshleifer May 26, 2020
63521f1
boom boom
sshleifer May 26, 2020
0c3d22e
boom boom
sshleifer May 26, 2020
81e939f
boom boom
sshleifer May 26, 2020
c10bb9f
boom boom
sshleifer May 26, 2020
d3611c9
Run distiller: no data defaults
sshleifer May 26, 2020
d2cd7b4
Somehow passing with better freezing logic
sshleifer May 26, 2020
e7d9dc3
dont default teacher
sshleifer May 26, 2020
ee18f81
Merge branch 'distilbart' of github.com:sshleifer/transformers_fork i…
sshleifer May 26, 2020
298ec53
--freeze_encoder
sshleifer May 26, 2020
f750d80
boom boom
sshleifer May 26, 2020
9b54356
Logging adjustments
sshleifer May 27, 2020
4f19747
boom boom
sshleifer May 27, 2020
d92a69f
boom boom
sshleifer May 27, 2020
99e7161
boom boom
sshleifer May 27, 2020
7f1af24
freeze decoder
sshleifer May 27, 2020
44743e8
fix ckpt bug maybe
sshleifer May 27, 2020
84ea9fa
Fixed tests
sshleifer May 27, 2020
0c3d6d8
boom boom
sshleifer May 27, 2020
516bbca
boom boom
sshleifer May 27, 2020
175efd8
getstate
sshleifer May 27, 2020
3c34cae
distrib hacks
sshleifer May 27, 2020
3f6463a
attempt logger.name fix
sshleifer May 27, 2020
324d4dc
No wandb if ddp
sshleifer May 27, 2020
eb4b3cc
boom boom
sshleifer May 27, 2020
d53236c
boom boom
sshleifer May 27, 2020
7c52e73
boom boom
sshleifer May 27, 2020
25edee1
boom boom
sshleifer May 27, 2020
6689b59
boom boom
sshleifer May 27, 2020
8fb24ce
boom boom
sshleifer May 27, 2020
bd361ea
rename generic_train -> build_trainer
sshleifer May 31, 2020
105d83d
undo rename
sshleifer May 31, 2020
7b9557f
boom boom
sshleifer May 31, 2020
60a9014
rouge tracker fix
sshleifer Jun 2, 2020
82a6b83
Merge branch 'master' into distilbart
sshleifer Jun 2, 2020
e17df74
merged master
sshleifer Jun 2, 2020
789f06e
less annoying requirements
sshleifer Jun 2, 2020
fb6adb2
backup rouge dfs
sshleifer Jun 3, 2020
b11fc69
boom boom
sshleifer Jun 3, 2020
8c6769e
boom boom
sshleifer Jun 3, 2020
df0b28a
boom boom
sshleifer Jun 4, 2020
8f614af
broken
sshleifer Jun 4, 2020
a2ba031
boom boom
sshleifer Jun 4, 2020
440744a
Distiller brewer does not break
sshleifer Jun 4, 2020
20aedfb
boom boom
sshleifer Jun 4, 2020
3e47e61
Disable cos loss
sshleifer Jun 5, 2020
a65ab52
boom boom
sshleifer Jun 5, 2020
cc83d34
boom boom
sshleifer Jun 5, 2020
0b98750
boom boom
sshleifer Jun 5, 2020
8d4f075
boom boom
sshleifer Jun 5, 2020
24f483c
boom boom
sshleifer Jun 5, 2020
e3a00a4
boom boom
sshleifer Jun 5, 2020
8b98c4c
boom boom
sshleifer Jun 6, 2020
8f68b91
boom boom
sshleifer Jun 6, 2020
055167b
boom boom
sshleifer Jun 6, 2020
e446fd2
boom boom
sshleifer Jun 6, 2020
014c061
boom boom
sshleifer Jun 6, 2020
e146713
passing
sshleifer Jun 6, 2020
9fc692d
boom boom
sshleifer Jun 6, 2020
3abbf78
boom boom
sshleifer Jun 7, 2020
47c226e
Merge branch 'distilbart' of github.com:sshleifer/transformers_fork i…
sshleifer Jun 7, 2020
bb1bffc
Save top1
sshleifer Jun 7, 2020
09f199b
Merge branch 'distilbart' of github.com:sshleifer/transformers_fork i…
sshleifer Jun 7, 2020
3b304af
boom boom
sshleifer Jun 7, 2020
c294e5e
boom boom
sshleifer Jun 7, 2020
7ccd236
Evaluate_checkpoint
sshleifer Jun 8, 2020
b32d283
boom boom
sshleifer Jun 8, 2020
c59be9b
boom boom
sshleifer Jun 8, 2020
ecff6f9
V8 and generate_summaries fixes
sshleifer Jun 8, 2020
e46e101
boom boom
sshleifer Jun 8, 2020
bcb4624
boom boom
sshleifer Jun 8, 2020
3467f1c
boom boom
sshleifer Jun 8, 2020
582efc0
boom boom
sshleifer Jun 8, 2020
9f6596e
boom boom
sshleifer Jun 8, 2020
3c38444
boom boom
sshleifer Jun 8, 2020
630af0c
boom boom
sshleifer Jun 9, 2020
318a585
boom boom
sshleifer Jun 9, 2020
31b4c8c
boom boom
sshleifer Jun 9, 2020
fa7f818
t5 failing
sshleifer Jun 9, 2020
0dc32e6
boom boom
sshleifer Jun 10, 2020
a79f5fb
boom boom
sshleifer Jun 10, 2020
0098872
Deleted SummarizationDistiller
sshleifer Jun 10, 2020
a387c05
boom boom
sshleifer Jun 10, 2020
76377c9
boom boom
sshleifer Jun 10, 2020
315ab6b
boom boom
sshleifer Jun 10, 2020
b5adf48
boom boom
sshleifer Jun 10, 2020
d3b8694
boom boom
sshleifer Jun 10, 2020
dc785f0
boom boom
sshleifer Jun 10, 2020
af05c0e
boom boom
sshleifer Jun 10, 2020
913bb8f
boom boom
sshleifer Jun 10, 2020
d26f5da
stuck on t5 mask
sshleifer Jun 10, 2020
c553e12
t5 works
sshleifer Jun 10, 2020
7f3f9fa
boom boom
sshleifer Jun 10, 2020
6624b63
boom boom
sshleifer Jun 10, 2020
bad1b75
boom boom
sshleifer Jun 10, 2020
a549aaa
boom boom
sshleifer Jun 10, 2020
574ffbb
boom boom
sshleifer Jun 10, 2020
b2b222d
boom boom
sshleifer Jun 10, 2020
a2f2d5d
undo some chg
sshleifer Jun 10, 2020
6aa28bd
ignore student
sshleifer Jun 10, 2020
585bb57
passing cpu
sshleifer Jun 10, 2020
5a28144
fp16_ever=False
sshleifer Jun 11, 2020
e5728e4
boom boom
sshleifer Jun 11, 2020
10b092e
boom boom
sshleifer Jun 11, 2020
ac47dbd
boom boom
sshleifer Jun 11, 2020
c7d05f7
boom boom
sshleifer Jun 11, 2020
4224ac9
boom boom
sshleifer Jun 11, 2020
3f3040a
boom boom
sshleifer Jun 11, 2020
9d99761
passing
sshleifer Jun 12, 2020
bdf0b1a
boom boom
sshleifer Jun 12, 2020
fb6eb09
Failing but minimal
sshleifer Jun 12, 2020
af3f86f
remove some cruft
sshleifer Jun 12, 2020
c5a1de5
boom boom
sshleifer Jun 13, 2020
e67ad41
boom boom
sshleifer Jun 13, 2020
0274338
original tests pass
sshleifer Jun 13, 2020
abee3e2
Merge branch 'master' into distilbart-clean
sshleifer Jun 13, 2020
de7035d
style
sshleifer Jun 13, 2020
e2d4544
style
sshleifer Jun 13, 2020
eab9779
better mask logic
sshleifer Jun 14, 2020
350eeb7
style
sshleifer Jun 14, 2020
3e62d96
add git-python requirement
sshleifer Jun 14, 2020
379e8c7
Cleanup
sshleifer Jun 14, 2020
35a82ee
boom boom
sshleifer Jun 14, 2020
6bf996f
boom boom
sshleifer Jun 14, 2020
5526be3
boom boom
sshleifer Jun 14, 2020
eb84b9d
boom boom
sshleifer Jun 14, 2020
4606bda
Bash cleanup
sshleifer Jun 14, 2020
9f98a6a
docs
sshleifer Jun 14, 2020
969c271
more honest docs
sshleifer Jun 14, 2020
f179e7b
Allow wandb logger
sshleifer Jun 14, 2020
c34b886
Wandb logger
sshleifer Jun 14, 2020
c9597fe
docs
sshleifer Jun 14, 2020
9e95429
docs
sshleifer Jun 14, 2020
99de2c3
pass through logger
sshleifer Jun 14, 2020
6bdfb14
boom boom
sshleifer Jun 14, 2020
5283878
Move stuff to utils
sshleifer Jun 14, 2020
8d867e9
on_save_checkpoint
sshleifer Jun 15, 2020
b7bb7cb
Fix decoder mask
sshleifer Jun 15, 2020
9508199
Better logger name
sshleifer Jun 15, 2020
68f6ccd
Fixed merge conflicts
sshleifer Jun 15, 2020
1ffd6cb
style
sshleifer Jun 15, 2020
0d592c5
fix import
sshleifer Jun 16, 2020
7ff1d78
Merge branch 'master' into distilbart-clean
sshleifer Jun 16, 2020
99a4866
Merge branch 'master' into distilbart-clean
sshleifer Jun 16, 2020
d25442a
rename -> SummarizationModule
sshleifer Jun 16, 2020
b56b4d8
more tips
sshleifer Jun 16, 2020
2deff3b
Fix README
sshleifer Jun 17, 2020
f28fc63
cleanup
sshleifer Jun 17, 2020
0ce2375
Cleanup more
sshleifer Jun 17, 2020
b7e1d5e
indent
sshleifer Jun 17, 2020
152 changes: 107 additions & 45 deletions examples/lightning_base.py
@@ -2,6 +2,8 @@
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict

import numpy as np
import pytorch_lightning as pl
@@ -13,10 +15,13 @@
AutoModel,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer,
get_linear_schedule_with_warmup,
)

@@ -31,40 +36,68 @@
"pretraining": AutoModelForPreTraining,
"token-classification": AutoModelForTokenClassification,
"language-modeling": AutoModelWithLMHead,
"summarization": AutoModelForSeq2SeqLM,
"translation": AutoModelForSeq2SeqLM,
}


def set_seed(args: argparse.Namespace):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
if args.gpus > 0:
torch.cuda.manual_seed_all(args.seed)


class BaseTransformer(pl.LightningModule):
def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
def __init__(
self,
hparams: argparse.Namespace,
num_labels=None,
mode="base",
config=None,
tokenizer=None,
model=None,
**config_kwargs
):
"Initialize a model."

super().__init__()
self.hparams = hparams
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
self.model = MODEL_MODES[mode].from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=self.config,
cache_dir=cache_dir,
)
if config is None:
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
)
else:
self.config: PretrainedConfig = config
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
else:
self.tokenizer: PreTrainedTokenizer = tokenizer
if model is None:
self.model_type = MODEL_MODES[mode]
self.model = self.model_type.from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=self.config,
cache_dir=cache_dir,
)
else:
self.model_type = None
self.model = model

def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)

def is_logger(self):
return self.trainer.proc_rank <= 0
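
The constructor now accepts pre-built `config`, `tokenizer`, and `model` objects instead of always going through the `Auto*` loaders, which is what lets a distillation script hand in a shrunken student model. A minimal sketch of that usage follows; the hparams values and model names are illustrative, not code from this PR.

```python
# Sketch only: constructing BaseTransformer with externally built objects.
# Assumes it is run from the examples/ directory; hparams fields below are placeholders.
import argparse

from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer

from lightning_base import BaseTransformer

hparams = argparse.Namespace(
    model_name_or_path="facebook/bart-large-cnn",  # only read when config/tokenizer/model are not supplied
    config_name=None,
    tokenizer_name=None,
    cache_dir=None,
    output_dir="distil_output",
)

config = BartConfig.from_pretrained(hparams.model_name_or_path)
tokenizer = BartTokenizer.from_pretrained(hparams.model_name_or_path)
student = BartForConditionalGeneration(config)  # e.g. a randomly initialized student to distill into

# Because config/tokenizer/model are passed in, __init__ skips its from_pretrained calls.
module = BaseTransformer(hparams, mode="summarization", config=config, tokenizer=tokenizer, model=student)
```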
@@ -138,6 +171,15 @@ def _feature_file(self, mode):
),
)

@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_path = self.output_dir.joinpath("best_tfmr")
save_path.mkdir(exist_ok=True)
self.model.config.save_step = self.step_count
self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
self.tfmr_ckpts[self.step_count] = save_path

@staticmethod
def add_model_specific_args(parser, root_dir):
parser.add_argument(
@@ -152,7 +194,7 @@ def add_model_specific_args(parser, root_dir)
)
parser.add_argument(
"--tokenizer_name",
default="",
default=None,
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
@@ -165,7 +207,7 @@ def add_model_specific_args(parser, root_dir):
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
)
@@ -199,7 +241,8 @@ def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
writer.write("{} = {}\n".format(key, str(metrics[key])))


def add_generic_args(parser, root_dir):
def add_generic_args(parser, root_dir) -> None:
# TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument(
"--output_dir",
default=None,
@@ -221,8 +264,8 @@ def add_generic_args(parser, root_dir):
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)

parser.add_argument("--n_gpu", type=int, default=1)
parser.add_argument("--fast_dev_run", action="store_true")
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument("--n_tpu_cores", type=int, default=0)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
@@ -235,28 +278,32 @@ def add_generic_args(parser, root_dir):
)

parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")


def generic_train(model: BaseTransformer, args: argparse.Namespace):
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--val_check_interval", default=1.0, type=float)


def generic_train(
model: BaseTransformer,
args: argparse.Namespace,
early_stopping_callback=False,
logger=True, # can pass WandbLogger() here
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
**extra_train_kwargs
):
# init model
set_seed(args)
odir = Path(model.hparams.output_dir)
odir.mkdir(exist_ok=True)
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
)
if logging_callback is None:
logging_callback = LoggingCallback()

if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
)

train_params = dict(
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.n_gpu,
max_epochs=args.num_train_epochs,
early_stop_callback=False,
gradient_clip_val=args.max_grad_norm,
checkpoint_callback=checkpoint_callback,
callbacks=[LoggingCallback()],
)
train_params = {}

if args.fp16:
train_params["use_amp"] = args.fp16
@@ -269,12 +316,27 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace):
train_params["num_tpu_cores"] = args.n_tpu_cores
train_params["gpus"] = 0

if args.n_gpu > 1:
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"

trainer = pl.Trainer(**train_params)
trainer = pl.Trainer(
logger=logger,
accumulate_grad_batches=args.gradient_accumulation_steps,
gpus=args.gpus,
max_epochs=args.num_train_epochs,
early_stop_callback=early_stopping_callback,
gradient_clip_val=args.max_grad_norm,
checkpoint_callback=checkpoint_callback,
callbacks=[logging_callback] + extra_callbacks,
fast_dev_run=args.fast_dev_run,
val_check_interval=args.val_check_interval,
weights_summary=None,
resume_from_checkpoint=args.resume_from_checkpoint,
**train_params,
)

if args.do_train:
trainer.fit(model)

trainer.logger.log_hyperparams(args)
trainer.logger.save()
return trainer
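
Putting the new `generic_train` signature together: the logger, checkpoint callback, and early-stopping callback are now injectable rather than hard-coded. Below is a rough sketch of how a fine-tuning script might drive it; the `SummarizationModule` import path, logger name, and example flags are assumptions rather than verbatim code from this PR.

```python
# Sketch only: driving the reworked generic_train from a script in examples/summarization/.
# Assumes SummarizationModule lives in finetune.py, as this PR's rename commit suggests.
import argparse
import os

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from finetune import SummarizationModule
from lightning_base import add_generic_args, generic_train

parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())
SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()  # e.g. --data_dir $CNN_DIR --output_dir my_run --do_train --gpus 1

model = SummarizationModule(args)
checkpoint = pl.callbacks.ModelCheckpoint(
    filepath=args.output_dir, monitor="val_loss", mode="min", save_top_k=1
)
trainer = generic_train(
    model,
    args,
    logger=WandbLogger(name="distilbart-sketch"),  # or logger=True for the default logger
    checkpoint_callback=checkpoint,
    early_stopping_callback=False,
)
# After training, {output_dir}/best_tfmr holds the transformers-format checkpoint
# written by the new on_save_checkpoint hook.
```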
3 changes: 2 additions & 1 deletion examples/requirements.txt
@@ -5,5 +5,6 @@ psutil
sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.7.3 # April 10, 2020 release
pytorch-lightning==0.7.6 # April 10, 2020 release
matplotlib
git-python==1.0.3
78 changes: 50 additions & 28 deletions examples/summarization/README.md
@@ -1,47 +1,69 @@
### Get CNN Data
To be able to reproduce the authors' results on the CNN/Daily Mail dataset you first need to download both CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") in the same folder. Then uncompress the archives by running:
### Data

CNN/DailyMail data
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz
export CNN_DIR=${PWD}/cnn_dm
```

This should make a directory called cnn_dm/ with files like `test.source`.
To use your own data, follow the same format: each article to be summarized is on its own line.
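
If you are building your own dataset, a small sketch of the expected layout is below. The `.source`/`.target` pairing and the train/val/test split names are assumed to mirror the downloaded cnn_dm/ folder; only `test.source` is named explicitly above.

```python
# Sketch only: writing a custom dataset in the one-example-per-line layout described above.
# Split names (train/val/test) and the .source/.target pairing are assumed to match cnn_dm/.
from pathlib import Path

articles = [
    "First article to summarize, kept on a single line.",
    "Second article to summarize, also on a single line.",
]
summaries = [
    "First reference summary.",
    "Second reference summary.",
]

data_dir = Path("my_dataset")
data_dir.mkdir(exist_ok=True)
for split in ("train", "val", "test"):
    # A real dataset would of course use different examples per split.
    (data_dir / f"{split}.source").write_text("\n".join(articles) + "\n")
    (data_dir / f"{split}.target").write_text("\n".join(summaries) + "\n")
```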

XSUM Data:
```bash
cd examples/summarization
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
tar -xzvf xsum.tar.gz
export XSUM_DIR=${PWD}/xsum
```


### Evaluation

To create summaries for each article in the dataset, run:
```bash
python evaluate_cnn.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
python run_eval.py <path_to_test.source> test_generations.txt <model-name> --score_path rouge_scores.txt
```
The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
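
The `--score_path` file holds the ROUGE results. If you want to score a generations file yourself, here is a minimal sketch using the `rouge-score` package already listed in `examples/requirements.txt`; it is not the exact logic inside `run_eval.py`, and the paths follow the commands above.

```python
# Sketch only: scoring a generations file against references with the rouge-score package
# from examples/requirements.txt. Adjust the file paths to your own run.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

with open("test_generations.txt") as hypos, open("cnn_dm/test.target") as refs:
    scores = [scorer.score(ref.strip(), hyp.strip()) for ref, hyp in zip(refs, hypos)]

# Average F1 over the whole test set.
rouge2_f1 = sum(s["rouge2"].fmeasure for s in scores) / len(scores)
print(f"ROUGE-2 F1: {rouge2_f1:.4f}")
```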


### Training
Run/modify `finetune_bart.sh` or `finetune_t5.sh`
Run/modify `finetune.sh`

### Stanford CoreNLP Setup
```
ptb_tokenize () {
cat $1 | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > $2
}

sudo apt install openjdk-8-jre-headless
sudo apt-get install ant
wget http://nlp.stanford.edu/software/stanford-corenlp-full-2018-10-05.zip
unzip stanford-corenlp-full-2018-10-05.zip
cd stanford-corenlp-full-2018-10-05
export CLASSPATH=stanford-corenlp-3.9.2.jar:stanford-corenlp-3.9.2-models.jar
```
The following command should work on a 16GB GPU:
```bash
export me=`git config user.name`
./finetune.sh \
--data_dir $XSUM_DIR \
--train_batch_size=1 \
--eval_batch_size=1 \
--output_dir="$me"_xsum_results \
--num_train_epochs 1
```
Then run `ptb_tokenize` on `test.target` and your generated hypotheses.
### Rouge Setup
Install `files2rouge` following the instructions [here](https://github.com/pltrdy/files2rouge).
I also needed to run `sudo apt-get install libxml-parser-perl`.

```python
from files2rouge import files2rouge
from files2rouge import settings
files2rouge.run(<path_to_tokenized_hypo>,
<path_to_tokenized_target>,
saveto='rouge_output.txt')
```

Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB of GPU RAM with fp16.
- Try `bart-base`, `--freeze_encoder`, or `--freeze_embeds` for faster training and a larger batch size (3 hr/epoch with bs=8; see below).
- `fp16_opt_level=O1` (the default) works best.
- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries.
(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')` (see the sketch after these tips).
- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
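
A short sketch of reloading that `best_tfmr` checkpoint and generating a summary; the output directory name is a placeholder for whatever you passed as `--output_dir`.

```python
# Sketch only: reloading the transformers checkpoint saved next to the lightning .ckpt files.
# "my_xsum_results" is a placeholder output_dir.
from transformers import AutoTokenizer, BartForConditionalGeneration

ckpt_dir = "my_xsum_results/best_tfmr"
model = BartForConditionalGeneration.from_pretrained(ckpt_dir)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)  # the tokenizer is saved there too

input_ids = tokenizer.encode("An article to summarize, on one line.", return_tensors="pt")
summary_ids = model.generate(input_ids, num_beams=4, max_length=60, early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```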

### Shared Task
Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
Here is an example command:
```bash
export me=`git config user.name`
./finetune.sh \
--data_dir $XSUM_DIR \
--output_dir "$me"_xsum_frozen_embs \
--logger wandb_shared \
--train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
--num_train_epochs 6
```

Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-)