diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md
new file mode 100644
index 000000000000..8b83501be390
--- /dev/null
+++ b/examples/seq2seq/README.md
@@ -0,0 +1,169 @@
+This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
+Summarization support is more mature than translation support.
+Please tag @sshleifer with any issues/unexpected behaviors, or send a PR!
+For `bertabs` instructions, see `bertabs/README.md`.
+
+### Data
+
+CNN/DailyMail data
+```bash
+cd examples/seq2seq
+wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
+tar -xzvf cnn_dm.tgz
+
+export CNN_DIR=${PWD}/cnn_dm
+```
+
+This should create a directory called `cnn_dm/` with files like `test.source`.
+To use your own data, mirror that format: each article to be summarized is on its own line.
+
+XSUM Data:
+```bash
+cd examples/seq2seq
+wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz
+tar -xzvf xsum.tar.gz
+export XSUM_DIR=${PWD}/xsum
+```
+
+WMT16 English-Romanian Translation Data:
+```bash
+cd examples/seq2seq
+wget https://s3.amazonaws.com/datasets.huggingface.co/translation/wmt_en_ro.tar.gz
+tar -xzvf wmt_en_ro.tar.gz
+export ENRO_DIR=${PWD}/wmt_en_ro
+```
+
+If you are using your own data, it must be formatted as one directory with 6 files: train.source, train.target, val.source, val.target, test.source, test.target.
+The `.source` files are the inputs; the `.target` files are the desired outputs.
+
+### Evaluation
+
+To create summaries for each article in the dataset, run:
+```bash
+python run_eval.py $CNN_DIR/test.source test_generations.txt facebook/bart-large-cnn \
+    --reference_path $CNN_DIR/test.target --score_path rouge_scores.txt
+```
+The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
+
+### Summarization Finetuning
+Run/modify `finetune.sh`.
+
+The following command should work on a 16GB GPU:
+```bash
+./finetune.sh \
+    --data_dir $XSUM_DIR \
+    --train_batch_size=1 \
+    --eval_batch_size=1 \
+    --output_dir=xsum_results \
+    --num_train_epochs 1 \
+    --model_name_or_path facebook/bart-large
+```
+
+*Note*: The following tips mostly apply to summarization finetuning.
+
+Tips:
+- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA V100.
+- try `bart-base`, `--freeze_encoder`, or `--freeze_embeds` for faster training/larger batch size (3hr/epoch with bs=8; see the "xsum_shared_task" command below).
+- `fp16_opt_level=O1` (the default) works best.
+- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
+(It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
+- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved.
+Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr')`.
+- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code.
+- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
+- For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores.
+To get accurate rouge scores, you should rerun `calculate_rouge` on the `{output_dir}/test_generations.txt` file saved by `trainer.test()` (see the sketch after this list).
+- `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
+- `wandb` can be used by specifying `--logger wandb_shared` or `--logger wandb`. It is useful for reproducibility.
+- This warning can be safely ignored:
+    > "Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-xsum and are newly initialized: ['final_logits_bias']"
+- Both finetuning and eval are 30% faster with `--fp16`. For that, you need to [install apex](https://github.com/NVIDIA/apex#quick-start).
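+
+For example, here is a minimal sketch of recomputing ROUGE on the saved test generations with full-length references. The paths are illustrative (they assume XSUM data and `--output_dir=xsum_results`); adjust them to your setup:
+```python
+# Recompute ROUGE on untruncated references, using the helper from examples/seq2seq/utils.py
+from utils import calculate_rouge
+
+generated = [x.rstrip() for x in open("xsum_results/test_generations.txt").readlines()]
+references = [x.rstrip() for x in open("xsum/test.target").readlines()]
+print(calculate_rouge(generated, references))  # dict keyed by ROUGE metric
+```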
+
+#### Finetuning Outputs
+As you train, `output_dir` will be filled with files that look something like this (comments are mine).
+Some of them are metrics, some of them are checkpoints, some of them are metadata. Here is a quick tour:
+
+```bash
+output_dir
+├── best_tfmr  # this is a huggingface checkpoint generated by save_pretrained. It is the same model as the PL .ckpt file below
+│   ├── config.json
+│   ├── merges.txt
+│   ├── pytorch_model.bin
+│   ├── special_tokens_map.json
+│   ├── tokenizer_config.json
+│   └── vocab.json
+├── git_log.json  # repo, branch, and commit hash
+├── val_avg_rouge2=0.1984-step_count=11.ckpt  # this is a pytorch lightning checkpoint associated with the best val score
+├── metrics.json  # new validation metrics will continually be appended to this
+├── student  # this is a huggingface checkpoint generated by SummarizationDistiller. It is the student before it gets finetuned.
+│   ├── config.json
+│   └── pytorch_model.bin
+├── test_generations.txt
+# ^^ the summaries or translations produced by your best checkpoint on the test data. Populated when training is done.
+├── test_results.txt  # a convenience file with the test set metrics. This data is also in metrics.json['test']
+├── hparams.pkl  # the command line args passed after some light preprocessing. Should be saved fairly quickly.
+```
+After training, you can recover the best checkpoint by running
+```python
+from transformers import AutoModelForSeq2SeqLM
+model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
+```
+
+### XSUM Shared Task
+Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
+
+Here is an example command, but you can adapt it however you like. Hopefully this will make debugging and collaboration easier!
+```bash
+./finetune.sh \
+    --data_dir $XSUM_DIR \
+    --output_dir xsum_frozen_embs \
+    --model_name_or_path facebook/bart-large \
+    --logger wandb_shared \
+    --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
+    --num_train_epochs 6 \
+    --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100
+```
+
+Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-).
+
+### Distilbart
+
+#### No Teacher Distillation
+To run the simpler, distilbart-cnn-style distillation, all you need is data, a GPU, and a properly initialized student.
+You don't even need `distillation.py`.
+
+Some [un-finetuned students](https://huggingface.co/models?search=sshleifer%2Fstudent) are available for replication purposes.
+They are initialized by copying layers from the associated `bart-large-{cnn|xsum}` teacher using `--init_strategy alternate`
+(you can read about that in `initialization_utils.py`).
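+
+For illustration, the alternating selection for a 12-layer teacher looks roughly like this; it is a standalone sketch that mirrors the `layers_to_copy` map in `distillation.py`, while the actual weight copying lives in `initialization_utils.py`:
+```python
+# Which teacher layers are copied into a student with N layers (12-layer teacher).
+LAYERS_TO_COPY = {
+    1: [0],
+    2: [0, 6],
+    3: [0, 6, 11],
+    4: [0, 4, 8, 11],
+    6: [0, 2, 4, 7, 9, 11],
+    9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
+    12: list(range(12)),
+}
+print(LAYERS_TO_COPY[6])  # the six teacher layers reused by a 6-layer student
+```
+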
+The command that produced `sshleifer/distilbart-cnn-12-6` is
+```bash
+./train_distilbart_cnn.sh
+```
+runtime: 6 hours on a 24GB NVIDIA RTX GPU
+
+*Note*: You can get the same simple distillation logic by using `./run_distiller.sh --no_teacher` followed by the same arguments as in `train_distilbart_cnn.sh`.
+If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
+because you will have the same hyperparameters logged in every run.
+
+#### With a teacher
+*Note*: only BART variants are supported.
+
+In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states, using `BartSummarizationDistiller`.
+This is how the `sshleifer/distilbart-xsum*` checkpoints were produced.
+
+The command that produced `sshleifer/distilbart-xsum-12-6` is:
+
+```bash
+./train_distilbart_xsum.sh
+```
+
+runtime: 13 hours on a 16GB V100 GPU.
+
+### Contributing
+- Follow the standard contributing guidelines and code of conduct.
+- Add tests to `test_seq2seq_examples.py`.
+- To run only the seq2seq tests, you must be in the root of the repository and run:
+```bash
+pytest examples/seq2seq/
+```
diff --git a/examples/summarization/__init__.py b/examples/seq2seq/__init__.py
similarity index 100%
rename from examples/summarization/__init__.py
rename to examples/seq2seq/__init__.py
diff --git a/examples/summarization/bertabs/README.md b/examples/seq2seq/bertabs/README.md
similarity index 99%
rename from examples/summarization/bertabs/README.md
rename to examples/seq2seq/bertabs/README.md
index 1307de6b3f75..7835e8bc84ce 100644
--- a/examples/summarization/bertabs/README.md
+++ b/examples/seq2seq/bertabs/README.md
@@ -12,7 +12,7 @@ The model is loaded with the pre-trained weights for the abstractive summarizati
 git clone https://github.com/huggingface/transformers && cd transformers
 pip install .
pip install nltk py-rouge -cd examples/summarization +cd examples/seq2seq/bertabs ``` ## Reproduce the authors' ROUGE score diff --git a/examples/summarization/bertabs/__init__.py b/examples/seq2seq/bertabs/__init__.py similarity index 100% rename from examples/summarization/bertabs/__init__.py rename to examples/seq2seq/bertabs/__init__.py diff --git a/examples/summarization/bertabs/configuration_bertabs.py b/examples/seq2seq/bertabs/configuration_bertabs.py similarity index 100% rename from examples/summarization/bertabs/configuration_bertabs.py rename to examples/seq2seq/bertabs/configuration_bertabs.py diff --git a/examples/summarization/bertabs/convert_bertabs_original_pytorch_checkpoint.py b/examples/seq2seq/bertabs/convert_bertabs_original_pytorch_checkpoint.py similarity index 100% rename from examples/summarization/bertabs/convert_bertabs_original_pytorch_checkpoint.py rename to examples/seq2seq/bertabs/convert_bertabs_original_pytorch_checkpoint.py diff --git a/examples/summarization/bertabs/modeling_bertabs.py b/examples/seq2seq/bertabs/modeling_bertabs.py similarity index 100% rename from examples/summarization/bertabs/modeling_bertabs.py rename to examples/seq2seq/bertabs/modeling_bertabs.py diff --git a/examples/summarization/bertabs/requirements.txt b/examples/seq2seq/bertabs/requirements.txt similarity index 100% rename from examples/summarization/bertabs/requirements.txt rename to examples/seq2seq/bertabs/requirements.txt diff --git a/examples/summarization/bertabs/run_summarization.py b/examples/seq2seq/bertabs/run_summarization.py similarity index 100% rename from examples/summarization/bertabs/run_summarization.py rename to examples/seq2seq/bertabs/run_summarization.py diff --git a/examples/summarization/bertabs/test_utils_summarization.py b/examples/seq2seq/bertabs/test_utils_summarization.py similarity index 100% rename from examples/summarization/bertabs/test_utils_summarization.py rename to examples/seq2seq/bertabs/test_utils_summarization.py diff --git a/examples/summarization/bertabs/utils_summarization.py b/examples/seq2seq/bertabs/utils_summarization.py similarity index 100% rename from examples/summarization/bertabs/utils_summarization.py rename to examples/seq2seq/bertabs/utils_summarization.py diff --git a/examples/summarization/callbacks.py b/examples/seq2seq/callbacks.py similarity index 74% rename from examples/summarization/callbacks.py rename to examples/seq2seq/callbacks.py index 83b54d08c7a6..523d18fcd7e9 100644 --- a/examples/summarization/callbacks.py +++ b/examples/seq2seq/callbacks.py @@ -32,9 +32,12 @@ def _write_logs( results_file = od / "test_results.txt" generations_file = od / "test_generations.txt" else: - results_file = od / f"{type_path}_results_{trainer.global_step:05d}.txt" - generations_file = od / f"{type_path}_generations_{trainer.global_step:05d}.txt" - + # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json + # If people want this it will be easy enough to add back. 
+ results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" + generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" + results_file.parent.mkdir(exist_ok=True) + generations_file.parent.mkdir(exist_ok=True) with open(results_file, "a+") as writer: for key in sorted(metrics): if key in ["log", "progress_bar", "preds"]: @@ -63,20 +66,25 @@ def on_train_start(self, trainer, pl_module): # mp stands for million parameters trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) - @rank_zero_only - def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - return self._write_logs(trainer, pl_module, "val") - @rank_zero_only def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): return self._write_logs(trainer, pl_module, "test") -def get_rouge2_checkpoint_callback(output_dir): +def get_checkpoint_callback(output_dir, metric): """Saves the best model by validation ROUGE2 score.""" + if metric == "rouge2": + exp = "{val_avg_rouge2:.4f}-{step_count}" + elif metric == "bleu": + exp = "{val_avg_bleu:.4f}-{step_count}" + else: + raise NotImplementedError( + f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." + ) + checkpoint_callback = ModelCheckpoint( - filepath=os.path.join(output_dir, "{val_avg_rouge2:.4f}-{step_count}"), - monitor="val_rouge", + filepath=os.path.join(output_dir, exp), + monitor=f"val_{metric}", mode="max", save_top_k=1, period=0, # maybe save a checkpoint every time val is run, not just end of epoch. diff --git a/examples/summarization/distillation.py b/examples/seq2seq/distillation.py similarity index 91% rename from examples/summarization/distillation.py rename to examples/seq2seq/distillation.py index 290dde0518d6..0d645a14585f 100644 --- a/examples/summarization/distillation.py +++ b/examples/seq2seq/distillation.py @@ -39,13 +39,12 @@ ) -class SummarizationDistiller(SummarizationModule): +class BartSummarizationDistiller(SummarizationModule): loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"] def __init__(self, hparams): assert Path(hparams.data_dir).exists() - - d_layers_to_copy, student, student_cfg, teacher = self.pre_init(hparams) + student, student_cfg, teacher = self.pre_init(hparams) super().__init__(hparams, model=student, config=student_cfg) self.teacher = teacher @@ -73,12 +72,15 @@ def sanity_check_gradients(self): del self.teacher.model.encoder def pre_init(self, hparams): - # Dump empty student model at a path, then call from_pretrained on it + self.output_dir = Path(hparams.output_dir) + self.output_dir.mkdir(exist_ok=True) teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval() student_updates = { "decoder_layers": hparams.student_decoder_layers, "encoder_layers": hparams.student_encoder_layers, } + if hparams.length_penalty != -1: + student_updates["length_penalty"] = hparams.length_penalty d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers) e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers) hparams.d_layer_to_copy = d_layers_to_copy @@ -89,9 +91,13 @@ def pre_init(self, hparams): student_cfg = BartConfig(**kw) student = BartForConditionalGeneration(student_cfg) student, _ = init_student(student, teacher) + save_dir = self.output_dir.joinpath("student") + save_dir.mkdir(exist_ok=True) + 
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) - Path(hparams.output_dir).mkdir(exist_ok=True) - return d_layers_to_copy, student, student_cfg, teacher + student.save_pretrained(save_dir) + hparams.model_name_or_path = str(save_dir) + return student, student_cfg, teacher def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher): if teacher.config.model_type == "t5": @@ -154,7 +160,6 @@ def calc_ce_loss(self, mask, s_logits, t_logits): def configure_optimizers(self): "Prepare optimizer and schedule (linear warmup and decay)" - model = self.model no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ @@ -180,18 +185,11 @@ def add_model_specific_args(parser, root_dir): # parser.add_argument("--alpha_cos", default=0.0, type=float) parser.add_argument("--alpha_encoder_loss", default=0.0, type=float) parser.add_argument("--alpha_hid", default=0.0, type=float, required=False) - parser.add_argument( - "--student_decoder_layers", default=12, type=int, required=False, - ) - parser.add_argument( - "--student_encoder_layers", default=12, type=int, required=False, - ) - parser.add_argument( - "--no_teacher", action="store_true", default=False, - ) - parser.add_argument( # TODO: remove - "--enc_only", action="store_true", default=False, - ) + parser.add_argument("--student_decoder_layers", default=12, type=int, required=False) + parser.add_argument("--student_encoder_layers", default=12, type=int, required=False) + parser.add_argument("--no_teacher", action="store_true", default=False) + parser.add_argument("--length_penalty", type=float, default=-1) + return parser def _step(self, batch): @@ -269,12 +267,14 @@ def calc_hidden_loss(self, attention_mask, hidden_states, hidden_states_T, match return sum(hidden_losses) -class T5SummarizationDistiller(SummarizationDistiller): +class T5SummarizationDistiller(BartSummarizationDistiller): def pre_init(self, hparams): raise NotImplementedError("T5 Distillation does not work yet") + self.output_dir = Path(hparams.output_dir) + self.output_dir.mkdir(exist_ok=True) teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher) n_layer = hparams.student_decoder_layers - assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this + assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6. 
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block)) e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block)) student_updates = {"num_layers": n_layer} @@ -291,8 +291,13 @@ def pre_init(self, hparams): Path(hparams.output_dir).mkdir(exist_ok=True) task_specific_params = student.config.task_specific_params if task_specific_params is not None: - student.config.update(task_specific_params.get("summarization", {})) - return d_layers_to_copy, student, student_cfg, teacher + student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode + save_dir = self.output_dir.joinpath("student") + save_dir.mkdir(exist_ok=True) + + student.save_pretrained(save_dir) + hparams.model_name_or_path = str(save_dir) + return student, student_cfg, teacher def freeze_embeds(self): freeze_params(self.model.shared) @@ -386,7 +391,7 @@ def create_module(args): elif args.enc_only: raise ValueError("Deleted that") else: - module_cls = SummarizationDistiller + module_cls = BartSummarizationDistiller args.setup_cls: str = module_cls.__name__ model = module_cls(args) return model @@ -418,18 +423,18 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None): def get_layers_to_copy(n_to_get, tot): all_layers = list(range(tot)) if tot == 12: # Alternating for special cases - layers_to_copy = { # maps # layers in student -> which teacher layers to copy - 6: [0, 2, 4, 7, 9, 11], - 1: [11], + layers_to_copy = { # maps num layers in student -> which teacher layers to copy + 1: [0], + 2: [0, 6], 3: [0, 6, 11], - 2: [0, 11], 4: [0, 4, 8, 11], + 6: [0, 2, 4, 7, 9, 11], 9: [0, 1, 2, 4, 5, 7, 9, 10, 11], 12: all_layers, } return layers_to_copy[n_to_get] else: - return all_layers[:n_to_get] + return all_layers[:n_to_get] # TODO: better version on theseus-bart branch def distill_main(args): @@ -443,7 +448,7 @@ def distill_main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd()) + parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd()) args = parser.parse_args() distill_main(args) diff --git a/examples/summarization/finetune.py b/examples/seq2seq/finetune.py similarity index 86% rename from examples/summarization/finetune.py rename to examples/seq2seq/finetune.py index a10d3f6511d8..98ebaf3f89c2 100644 --- a/examples/summarization/finetune.py +++ b/examples/seq2seq/finetune.py @@ -3,6 +3,7 @@ import logging import os import time +from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple @@ -23,12 +24,14 @@ flatten_list, pickle_save, save_git_info, + save_json, freeze_params, calculate_rouge, get_git_info, ROUGE_KEYS, + calculate_bleu_score, ) - from .callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback + from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback except ImportError: from utils import ( use_task_specific_params, @@ -37,12 +40,14 @@ flatten_list, pickle_save, save_git_info, + save_json, freeze_params, calculate_rouge, get_git_info, ROUGE_KEYS, + calculate_bleu_score, ) - from callbacks import Seq2SeqLoggingCallback, get_rouge2_checkpoint_callback + from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback logger = logging.getLogger(__name__) @@ -50,15 +55,18 @@ class SummarizationModule(BaseTransformer): mode = "summarization" loss_names = ["loss"] + metric_names = ROUGE_KEYS + val_metric = "rouge2" def __init__(self, hparams, **kwargs): 
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs) use_task_specific_params(self.model, "summarization") save_git_info(self.hparams.output_dir) - self.metrics_save_path = Path(self.output_dir) / "metrics.pkl" + self.metrics_save_path = Path(self.output_dir) / "metrics.json" self.hparams_save_path = Path(self.output_dir) / "hparams.pkl" + pickle_save(self.hparams, self.hparams_save_path) self.step_count = 0 - self.metrics = {"train": [], "val": [], "test": []} + self.metrics = defaultdict(list) self.dataset_kwargs: dict = dict( data_dir=self.hparams.data_dir, @@ -89,12 +97,12 @@ def __init__(self, hparams, **kwargs): def freeze_embeds(self): """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" - if self.model.config.model_type == "bart": + try: freeze_params(self.model.model.shared) for d in [self.model.model.encoder, self.model.model.decoder]: freeze_params(d.embed_positions) freeze_params(d.embed_tokens) - else: + except AttributeError: freeze_params(self.model.shared) for d in [self.model.encoder, self.model.decoder]: freeze_params(d.embed_tokens) @@ -130,19 +138,22 @@ def validation_epoch_end(self, outputs, prefix="val") -> Dict: self.step_count += 1 losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} loss = losses["loss"] - rouges = {k: np.array([x[k] for x in outputs]).mean() for k in ROUGE_KEYS + ["gen_time", "summ_len"]} - rouge_tensor: torch.FloatTensor = torch.tensor(rouges["rouge2"]).type_as(loss) + rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "summ_len"]} + rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss) rouges.update({k: v.item() for k, v in losses.items()}) losses.update(rouges) metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} metrics["step_count"] = self.step_count self.save_metrics(metrics, prefix) # writes to self.metrics_save_path preds = flatten_list([x["preds"] for x in outputs]) - return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_rouge": rouge_tensor} + return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor} + + def save_metrics(self, latest_metrics, type_path) -> None: + self.metrics[type_path].append(latest_metrics) + save_json(self.metrics, self.metrics_save_path) - def save_metrics(self, metrics, prefix) -> None: - self.metrics[prefix].append(metrics) - pickle_save(self.metrics, self.metrics_save_path) + def calc_generative_metrics(self, preds, target) -> Dict: + return calculate_rouge(preds, target) def _generative_step(self, batch: dict) -> dict: pad_token_id = self.tokenizer.pad_token_id @@ -154,7 +165,7 @@ def _generative_step(self, batch: dict) -> dict: target = self.ids_to_clean_text(y) loss_tensors = self._step(batch) base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} - rouge: Dict = calculate_rouge(preds, target) + rouge: Dict = self.calc_generative_metrics(preds, target) summ_len = np.mean(lmap(len, generated_ids)) base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge) return base_metrics @@ -259,15 +270,33 @@ def add_model_specific_args(parser, root_dir): parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. 
-1 means use all.") parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") + parser.add_argument( + "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all." + ) return parser +class TranslationModule(SummarizationModule): + mode = "translation" + loss_names = ["loss"] + metric_names = ["bleu"] + val_metric = "bleu" + + def calc_generative_metrics(self, preds, target) -> dict: + return calculate_bleu_score(preds, target) + + def main(args, model=None) -> SummarizationModule: Path(args.output_dir).mkdir(exist_ok=True) if len(os.listdir(args.output_dir)) > 3 and args.do_train: raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) if model is None: - model: BaseTransformer = SummarizationModule(args) + if args.task == "summarization": + model: SummarizationModule = SummarizationModule(args) + else: + model: SummarizationModule = TranslationModule(args) + + dataset = Path(args.data_dir).name if ( args.logger == "default" or args.fast_dev_run @@ -278,17 +307,17 @@ def main(args, model=None) -> SummarizationModule: elif args.logger == "wandb": from pytorch_lightning.loggers import WandbLogger - logger = WandbLogger(name=model.output_dir.name) + logger = WandbLogger(name=model.output_dir.name, project=dataset) + elif args.logger == "wandb_shared": from pytorch_lightning.loggers import WandbLogger - # TODO: separate LB for CNN, we should use Path(args.data_dir).name to determine the correct LB. - logger = WandbLogger(name=model.output_dir.name, project="hf_summarization") + logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") trainer: pl.Trainer = generic_train( model, args, logging_callback=Seq2SeqLoggingCallback(), - checkpoint_callback=get_rouge2_checkpoint_callback(args.output_dir), + checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), logger=logger, # TODO: early stopping callback seems messed up ) diff --git a/examples/summarization/finetune.sh b/examples/seq2seq/finetune.sh similarity index 64% rename from examples/summarization/finetune.sh rename to examples/seq2seq/finetune.sh index b3ee8c90667f..89bd68eaae74 100755 --- a/examples/summarization/finetune.sh +++ b/examples/seq2seq/finetune.sh @@ -1,13 +1,8 @@ - # Add parent directory to python path to access lightning_base.py export PYTHONPATH="../":"${PYTHONPATH}" - -# --model_name_or_path=t5-base for t5 - -# the proper usage is documented in the README +# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path python finetune.py \ - --model_name_or_path=facebook/bart-large \ --learning_rate=3e-5 \ --fp16 \ --gpus 1 \ @@ -16,5 +11,4 @@ python finetune.py \ --n_val 1000 \ --val_check_interval 0.1 \ --sortish_sampler \ - --max_target_length=56 \ $@ diff --git a/examples/summarization/finetune_bart_tiny.sh b/examples/seq2seq/finetune_bart_tiny.sh old mode 100644 new mode 100755 similarity index 100% rename from examples/summarization/finetune_bart_tiny.sh rename to examples/seq2seq/finetune_bart_tiny.sh diff --git a/examples/summarization/finetune_t5.sh b/examples/seq2seq/finetune_t5.sh old mode 100644 new mode 100755 similarity index 100% rename from examples/summarization/finetune_t5.sh rename to examples/seq2seq/finetune_t5.sh diff --git a/examples/summarization/initialization_utils.py b/examples/seq2seq/initialization_utils.py similarity index 100% rename from 
examples/summarization/initialization_utils.py rename to examples/seq2seq/initialization_utils.py diff --git a/examples/summarization/run_distiller.sh b/examples/seq2seq/run_distiller.sh similarity index 73% rename from examples/summarization/run_distiller.sh rename to examples/seq2seq/run_distiller.sh index 6fbecad388c2..61456490fd48 100755 --- a/examples/summarization/run_distiller.sh +++ b/examples/seq2seq/run_distiller.sh @@ -1,5 +1,3 @@ -#CNN_DIR = /home/shleifer/transformers_fork/examples/summarization/bart/cnn_dm - # Add parent directory to python path to access lightning_base.py export PYTHONPATH="../":"${PYTHONPATH}" diff --git a/examples/summarization/run_eval.py b/examples/seq2seq/run_eval.py similarity index 67% rename from examples/summarization/run_eval.py rename to examples/seq2seq/run_eval.py index 0bbaf9d64adc..82699d1f2265 100644 --- a/examples/summarization/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -9,9 +9,9 @@ try: - from .finetune import calculate_rouge, use_task_specific_params + from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score except ImportError: - from finetune import calculate_rouge, use_task_specific_params + from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" @@ -22,8 +22,14 @@ def chunks(lst, n): yield lst[i : i + n] -def generate_summaries( - examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False, +def generate_summaries_or_translations( + examples: list, + out_file: str, + model_name: str, + batch_size: int = 8, + device: str = DEFAULT_DEVICE, + fp16=False, + **gen_kwargs, ) -> None: fout = Path(out_file).open("w", encoding="utf-8") model_name = str(model_name) @@ -39,11 +45,10 @@ def generate_summaries( for batch in tqdm(list(chunks(examples, batch_size))): if "t5" in model_name: batch = [model.config.prefix + text for text in batch] - dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True).to( - device - ) - summaries = model.generate(**dct) - + batch = tokenizer.batch_encode_plus( + batch, max_length=1024, return_tensors="pt", truncation=True, pad_to_max_length=True + ).to(device) + summaries = model.generate(**batch, **gen_kwargs) dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) for hypothesis in dec: fout.write(hypothesis + "\n") @@ -57,22 +62,26 @@ def run_generate(): parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.") parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt") parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format") + parser.add_argument("--metric", type=str, choices=["bleu", "rouge"], default="rouge") parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.") parser.add_argument("--bs", type=int, default=8, required=False, help="batch size") parser.add_argument("--fp16", action="store_true") args = parser.parse_args() examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()] - generate_summaries( + generate_summaries_or_translations( examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device, fp16=args.fp16 ) - if args.score_path is not None: - output_lns = 
[x.rstrip() for x in open(args.output_path).readlines()] - reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] - rouge: dict = calculate_rouge(output_lns, reference_lns) - - json.dump(rouge, open("score_path", "w+")) + output_lns = [x.rstrip() for x in open(args.output_path).readlines()] + scores = {} + if args.reference_path is not None: + score_fn = {"bleu": calculate_bleu_score, "rouge": calculate_rouge}[args.metric] + reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()] + scores: dict = score_fn(output_lns, reference_lns) + if args.score_path is not None: + json.dump(scores, open("score_path", "w+")) + return scores if __name__ == "__main__": diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py new file mode 100644 index 000000000000..6fb1144b2ec8 --- /dev/null +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -0,0 +1,252 @@ +import argparse +import logging +import os +import sys +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +import pytest +import torch +from torch.utils.data import DataLoader + +from transformers import AutoTokenizer + +from .distillation import distill_main, evaluate_checkpoint +from .finetune import main +from .run_eval import generate_summaries_or_translations, run_generate +from .utils import SummarizationDataset, lmap, load_json + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +CUDA_AVAILABLE = torch.cuda.is_available() +CHEAP_ARGS = { + "logger": "default", + "length_penalty": 0.5, + "cache_dir": "", + "task": "summarization", + "num_workers": 2, + "alpha_hid": 0, + "freeze_embeds": True, + "enc_only": False, + "tgt_suffix": "", + "resume_from_checkpoint": None, + "sortish_sampler": True, + "student_decoder_layers": 1, + "val_check_interval": 1.0, + "output_dir": "", + "fp16": CUDA_AVAILABLE, + "no_teacher": False, + "fp16_opt_level": "O1", + "gpus": 1 if CUDA_AVAILABLE else 0, + "n_tpu_cores": 0, + "max_grad_norm": 1.0, + "do_train": True, + "do_predict": True, + "gradient_accumulation_steps": 1, + "server_ip": "", + "server_port": "", + "seed": 42, + "model_name_or_path": "sshleifer/bart-tiny-random", + "config_name": "", + "tokenizer_name": "facebook/bart-large", + "do_lower_case": False, + "learning_rate": 0.3, + "weight_decay": 0.0, + "adam_epsilon": 1e-08, + "warmup_steps": 0, + "num_train_epochs": 1, + "train_batch_size": 2, + "eval_batch_size": 2, + "max_source_length": 12, + "max_target_length": 12, + "val_max_target_length": 12, + "test_max_target_length": 12, + "fast_dev_run": False, + "no_cache": False, + "n_train": -1, + "n_val": -1, + "n_test": -1, + "student_encoder_layers": 1, + "alpha_loss_encoder": 0.0, + "freeze_encoder": False, + "auto_scale_batch_size": False, +} + + +def _dump_articles(path: Path, articles: list): + with path.open("w") as f: + f.write("\n".join(articles)) + + +ARTICLES = [" Sam ate lunch today", "Sams lunch ingredients"] +SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] +T5_TINY = "patrickvonplaten/t5-tiny-random" +BART_TINY = "sshleifer/bart-tiny-random" +MBART_TINY = "sshleifer/tiny-mbart" +MARIAN_TINY = "sshleifer/tiny-marian-en-de" +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) +logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks + + +def make_test_data_dir(**kwargs): + tmp_dir = Path(tempfile.mkdtemp(**kwargs)) + for split in 
["train", "val", "test"]: + _dump_articles((tmp_dir / f"{split}.source"), ARTICLES) + _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES) + return tmp_dir + + +class TestSummarizationDistiller(unittest.TestCase): + @classmethod + def setUpClass(cls): + logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks + return cls + + @unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test") + def test_multigpu(self): + updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,) + self._test_distiller_cli(updates) + + def test_distill_no_teacher(self): + updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True) + self._test_distiller_cli(updates) + + def test_distill_checkpointing_with_teacher(self): + updates = dict( + student_encoder_layers=2, + student_decoder_layers=1, + num_train_epochs=4, + val_check_interval=0.25, + alpha_hid=2.0, + model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED", + ) + model = self._test_distiller_cli(updates, check_contents=False) + + ckpts = list(Path(model.output_dir).glob("*.ckpt")) + self.assertEqual(1, len(ckpts)) + transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) + self.assertEqual(len(transformer_ckpts), 2) + examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines()) + out_path = tempfile.mktemp() + generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr")) + self.assertTrue(Path(out_path).exists()) + + evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) + + @unittest.skip("T5 distillation is broken at the moment") + def test_distill_t5(self): + updates = dict( + student_encoder_layers=1, + student_decoder_layers=1, + alpha_hid=2.0, + teacher=T5_TINY, + model_name_or_path=T5_TINY, + tokenizer_name=T5_TINY, + ) + self._test_distiller_cli(updates) + + def _test_distiller_cli(self, updates, check_contents=True): + default_updates = dict( + train_batch_size=1, + eval_batch_size=2, + num_train_epochs=2, + alpha_mlm=0.2, + alpha_ce=0.8, + do_predict=True, + model_name_or_path="sshleifer/tinier_bart", + teacher=CHEAP_ARGS["model_name_or_path"], + val_check_interval=0.5, + alpha_encoder_loss=0.4, + ) + default_updates.update(updates) + args_d: dict = CHEAP_ARGS.copy() + tmp_dir = make_test_data_dir() + output_dir = tempfile.mkdtemp(prefix="output_") + + args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates) + model = distill_main(argparse.Namespace(**args_d)) + if not check_contents: + return model + contents = os.listdir(output_dir) + ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt" + contents = {os.path.basename(p) for p in contents} + self.assertIn(ckpt_name, contents) + + self.assertIn("test_generations.txt", contents) + self.assertIn("test_results.txt", contents) + + metrics = load_json(model.metrics_save_path) + last_step_stats = metrics["val"][-1] + self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01) + self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"]) + self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float) + desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1) + self.assertEqual(len(metrics["val"]), desired_n_evals) + self.assertEqual(len(metrics["test"]), 1) + return model + + +@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), 
pytest.param(MBART_TINY)]) +def test_run_eval_bart(model): + tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo" + + output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo" + assert not output_file_name.exists() + articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] + _dump_articles(tmp, articles) + testargs = ["run_eval.py", str(tmp), str(output_file_name), model] # TODO: test score_path + with patch.object(sys, "argv", testargs): + run_generate() + assert Path(output_file_name).exists() + os.remove(Path(output_file_name)) + + +@pytest.mark.parametrize( + ["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)] +) +def test_finetune(model): + args_d: dict = CHEAP_ARGS.copy() + task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization" + tmp_dir = make_test_data_dir() + output_dir = tempfile.mkdtemp(prefix="output_") + args_d.update( + data_dir=tmp_dir, + model_name_or_path=model, + tokenizer_name=None, + train_batch_size=2, + eval_batch_size=2, + output_dir=output_dir, + do_predict=True, + task=task, + ) + assert "n_train" in args_d + args = argparse.Namespace(**args_d) + main(args) + + +@pytest.mark.parametrize( + ["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)] +) +def test_dataset(tok): + tokenizer = AutoTokenizer.from_pretrained(tok) + tmp_dir = make_test_data_dir() + max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES) + max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES) + trunc_target = 4 + train_dataset = SummarizationDataset( + tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target, + ) + dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) + for batch in dataloader: + assert batch["attention_mask"].shape == batch["input_ids"].shape + # show that articles were trimmed. 
+ assert batch["input_ids"].shape[1] == max_len_source + assert 20 >= batch["input_ids"].shape[1] # trimmed significantly + # show that targets were truncated + assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated + assert max_len_target > trunc_target # Truncated diff --git a/examples/seq2seq/train_distilbart_cnn.sh b/examples/seq2seq/train_distilbart_cnn.sh new file mode 100755 index 000000000000..7ab25e34e955 --- /dev/null +++ b/examples/seq2seq/train_distilbart_cnn.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +export PYTHONPATH="../":"${PYTHONPATH}" + +export BS=32 +export GAS=1 + +python finetune.py \ + --learning_rate=3e-5 \ + --fp16 \ + --gpus 1 \ + --do_train \ + --do_predict \ + --val_check_interval 0.25 \ + --n_val 500 \ + --num_train_epochs 2 \ + --freeze_encoder --freeze_embeds --data_dir $CNN_DIR \ + --max_target_length 142 --val_max_target_length=142 \ + --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \ + --data_dir $CNN_DIR \ + --model_name_or_path sshleifer/student_cnn_12_6 \ + --tokenizer_name facebook/bart-large \ + --output_dir distilbart-cnn-12-6 \ + $@ + diff --git a/examples/seq2seq/train_distilbart_xsum.sh b/examples/seq2seq/train_distilbart_xsum.sh new file mode 100755 index 000000000000..65f6ec8eb16a --- /dev/null +++ b/examples/seq2seq/train_distilbart_xsum.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +export PYTHONPATH="../":"${PYTHONPATH}" +export BS=16 +export GAS=2 +python distillation.py \ + --learning_rate=3e-4 \ + --do_train \ + --do_predict \ + --fp16 \ + --val_check_interval 0.1 --n_val 1000 \ + --teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \ + --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \ + --student_decoder_layers 6 --student_encoder_layers 12 \ + --freeze_encoder --freeze_embeds \ + --model_name_or_path IGNORED \ + --alpha_hid=3. 
--length_penalty=0.5 \ + --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \ + --tokenizer_name facebook/bart-large \ + --output_dir distilbart_xsum_12_6 \ + $@ diff --git a/examples/summarization/utils.py b/examples/seq2seq/utils.py similarity index 91% rename from examples/summarization/utils.py rename to examples/seq2seq/utils.py index bff6c3de3ed6..39cfa9d38056 100644 --- a/examples/summarization/utils.py +++ b/examples/seq2seq/utils.py @@ -3,12 +3,13 @@ import os import pickle from pathlib import Path -from typing import Dict, Iterable, List +from typing import Callable, Dict, Iterable, List import git import numpy as np import torch from rouge_score import rouge_scorer, scoring +from sacrebleu import corpus_bleu from torch import nn from torch.utils.data import Dataset, Sampler from tqdm import tqdm @@ -41,7 +42,7 @@ def encode_file( examples = [] for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"): tokenized = tokenizer.batch_encode_plus( - [text], # DONT ADD SPACES + [text], max_length=max_length, pad_to_max_length=pad_to_max_length, add_prefix_space=True, @@ -54,11 +55,13 @@ def encode_file( return examples -def lmap(f, x): +def lmap(f: Callable, x: Iterable) -> List: + """list(map(f, x))""" return list(map(f, x)) -T5_PREFIX = "summarize: " # HACK, fixme +def calculate_bleu_score(output_lns, refs_lns) -> dict: + return {"bleu": corpus_bleu(output_lns, [refs_lns]).score} def trim_batch( @@ -95,6 +98,8 @@ def __init__( tok_name=tok_name, ) tgt_path = os.path.join(data_dir, type_path + ".target") + if hasattr(tokenizer, "set_lang"): + tokenizer.set_lang("ro_RO") # HACK: only applies to mbart self.target = encode_file( tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name ) @@ -189,14 +194,20 @@ def flatten_list(summary_ids: List[List]): return [x for x in itertools.chain.from_iterable(summary_ids)] -def save_git_info(folder_path: str): - """ - Log commit info. - """ +def save_git_info(folder_path: str) -> None: + """Save git information to output_dir/git_log.json""" repo_infos = get_git_info() + save_json(repo_infos, os.path.join(folder_path, "git_log.json")) - with open(os.path.join(folder_path, "git_log.json"), "w") as f: - json.dump(repo_infos, f, indent=4) + +def save_json(content, path): + with open(path, "w") as f: + json.dump(content, f, indent=4) + + +def load_json(path): + with open(path) as f: + return json.load(f) def get_git_info(): diff --git a/examples/summarization/README.md b/examples/summarization/README.md deleted file mode 100644 index b626333cba8f..000000000000 --- a/examples/summarization/README.md +++ /dev/null @@ -1,70 +0,0 @@ -### Data - -CNN/DailyMail data -```bash -cd examples/summarization -wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz -tar -xzvf cnn_dm.tgz -export CNN_DIR=${PWD}/cnn_dm -``` - -this should make a directory called cnn_dm/ with files like `test.source`. -To use your own data, copy that files format. Each article to be summarized is on its own line. - -XSUM Data: -```bash -cd examples/summarization -wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/xsum.tar.gz -tar -xzvf xsum.tar.gz -export XSUM_DIR=${PWD}/xsum -``` - - -### Evaluation - -To create summaries for each article in dataset, run: -```bash -python run_eval.py test_generations.txt --score_path rouge_scores.txt -``` -The default batch size, 4, fits in 16GB GPU memory, but may need to be adjusted to fit your system. 
- - -### Training -Run/modify `finetune.sh` - -The following command should work on a 16GB GPU: -```bash -export me=`git config user.name` -./finetune.sh \ - --data_dir $XSUM_DIR \ - --train_batch_size=1 \ - --eval_batch_size=1 \ - --output_dir="$me"_xsum_results \ - --num_train_epochs 1 -``` - -Tips: -- 1 epoch at batch size 1 for bart-large takes 24 hours, requires 13GB GPU RAM with fp16 on an NVIDIA-V100. -- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see below) -- `fp16_opt_level=O1` (the default works best). -- If you are finetuning on your own dataset, start from `bart-large-cnn` if you want long summaries and `bart-large-xsum` if you want short summaries. -(It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods). -- In addition to the pytorch-lightning .ckpt checkpoint, a transformers checkpoint will be saved. -Load it with `BartForConditionalGeneration.from_pretrained(f'{output_dir}/best_tfmr)`. -- At the moment, `--do_predict` does not work in a multi-gpu setting. You need to use `evaluate_checkpoint` or the `run_eval.py` code. -- If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter. - -### XSUM Shared Task -Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration. -Here is an example command -```bash -export me=`git config user.name` -./finetune.sh \ - --data_dir $XSUM_DIR \ - --output_dir "$me"_xsum_frozen_embs \ - --logger wandb_shared \ - --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \ - --num_train_epochs 6 -``` - -Results can be viewed [here](https://app.wandb.ai/sshleifer/hf_summarization/table?workspace=user-) diff --git a/examples/summarization/test_summarization_examples.py b/examples/summarization/test_summarization_examples.py deleted file mode 100644 index d829793ce183..000000000000 --- a/examples/summarization/test_summarization_examples.py +++ /dev/null @@ -1,267 +0,0 @@ -import argparse -import logging -import os -import sys -import tempfile -import unittest -from pathlib import Path -from unittest.mock import patch - -import torch -from torch.utils.data import DataLoader - -from transformers import BartTokenizer - -from .distillation import distill_main, evaluate_checkpoint -from .finetune import main -from .run_eval import generate_summaries, run_generate -from .utils import SummarizationDataset, lmap, pickle_load - - -logging.basicConfig(level=logging.DEBUG) - -logger = logging.getLogger() -FP16_EVER = False -CHEAP_ARGS = { - "logger": "default", - "num_workers": 2, - "alpha_hid": 0, - "freeze_embeds": True, - "enc_only": False, - "tgt_suffix": "", - "resume_from_checkpoint": None, - "sortish_sampler": True, - "student_decoder_layers": 1, - "val_check_interval": 1.0, - "output_dir": "", - "fp16": False, - "no_teacher": False, - "fp16_opt_level": "O1", - "gpus": 1 if torch.cuda.is_available() else 0, - "n_tpu_cores": 0, - "max_grad_norm": 1.0, - "do_train": True, - "do_predict": True, - "gradient_accumulation_steps": 1, - "server_ip": "", - "server_port": "", - "seed": 42, - "model_type": "bart", - "model_name_or_path": "sshleifer/bart-tiny-random", - "config_name": "", - "tokenizer_name": "facebook/bart-large", - "cache_dir": "", - "do_lower_case": False, - "learning_rate": 3e-05, - "weight_decay": 0.0, - "adam_epsilon": 1e-08, - "warmup_steps": 0, 
- "num_train_epochs": 1, - "train_batch_size": 2, - "eval_batch_size": 2, - "max_source_length": 12, - "max_target_length": 12, - "val_max_target_length": 12, - "test_max_target_length": 12, - "fast_dev_run": False, - "no_cache": False, - "n_train": -1, - "n_val": -1, - "n_test": -1, - "student_encoder_layers": 1, - "alpha_loss_encoder": 0.0, - "freeze_encoder": False, - "auto_scale_batch_size": False, -} - - -def _dump_articles(path: Path, articles: list): - with path.open("w") as f: - f.write("\n".join(articles)) - - -MSG = "T5 is broken at the moment" -T5_TINY = "patrickvonplaten/t5-tiny-random" - - -def make_test_data_dir(): - tmp_dir = Path(tempfile.gettempdir()) - articles = [" Sam ate lunch today", "Sams lunch ingredients"] - summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] - for split in ["train", "val", "test"]: - _dump_articles((tmp_dir / f"{split}.source"), articles) - _dump_articles((tmp_dir / f"{split}.target"), summaries) - return tmp_dir - - -class TestSummarizationDistiller(unittest.TestCase): - @classmethod - def setUpClass(cls): - logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks - return cls - - @unittest.skipUnless(torch.cuda.device_count() > 1, "skipping multiGPU test") - def test_bdc_multigpu(self): - updates = dict( - student_encoder_layers=2, - student_decoder_layers=1, - no_teacher=True, - freeze_encoder=True, - gpus=2, - sortish_sampler=False, - fp16_opt_level="O1", - fp16=FP16_EVER, - ) - self._bart_distiller_cli(updates) - - def test_bdc_t5_train(self): - updates = dict( - fp16=FP16_EVER, - gpus=1 if torch.cuda.is_available() else 0, - model_type="t5", - model_name_or_path=T5_TINY, - do_train=True, - do_predict=True, - tokenizer_name=T5_TINY, - no_teacher=True, - alpha_hid=2.0, - ) - self._bart_distiller_cli(updates) - - def test_bdc_no_teacher(self): - updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True,) - self._bart_distiller_cli(updates) - - def test_bdc_yes_teacher(self): - updates = dict(student_encoder_layers=2, student_decoder_layers=1,) - self._bart_distiller_cli(updates) - - def test_bdc_checkpointing(self): - updates = dict( - student_encoder_layers=2, - student_decoder_layers=1, - num_train_epochs=4, - val_check_interval=0.25, - alpha_hid=2.0, - ) - model = self._bart_distiller_cli(updates, check_contents=False) - - ckpts = list(Path(model.output_dir).glob("*.ckpt")) - self.assertEqual(1, len(ckpts)) - transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) - self.assertEqual(len(transformer_ckpts), len(ckpts)) - new_transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin")) - self.assertEqual(len(new_transformer_ckpts), 1) - examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines()) - out_path = tempfile.mktemp() - generate_summaries(examples, out_path, new_transformer_ckpts[0].parent) - self.assertTrue(Path(out_path).exists()) - - evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) - - def _bart_distiller_cli(self, updates, check_contents=True): - default_updates = dict( - train_batch_size=1, - eval_batch_size=2, - num_train_epochs=2, - alpha_mlm=0.2, - alpha_ce=0.8, - do_predict=True, - gpus=1 if torch.cuda.is_available() else 0, - model_name_or_path="sshleifer/tinier_bart", - teacher=CHEAP_ARGS["model_name_or_path"], - val_check_interval=0.5, - alpha_encoder_loss=0.4, - ) - default_updates.update(updates) - args_d: dict = CHEAP_ARGS.copy() - tmp_dir = 
make_test_data_dir() - output_dir = tempfile.mkdtemp(prefix="output_") - - args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates) - model = distill_main(argparse.Namespace(**args_d)) - if not check_contents: - return model - contents = os.listdir(output_dir) - ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt" - contents = {os.path.basename(p) for p in contents} - self.assertIn(ckpt_name, contents) - self.assertIn("metrics.pkl", contents) - self.assertIn("test_generations.txt", contents) - self.assertIn("val_generations_00001.txt", contents) - self.assertIn("val_results_00001.txt", contents) - self.assertIn("test_results.txt", contents) - - metrics = pickle_load(Path(output_dir) / "metrics.pkl") - desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1) - self.assertEqual(len(metrics["val"]), desired_n_evals) - self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here - return model - - -class TestBartExamples(unittest.TestCase): - @classmethod - def setUpClass(cls): - stream_handler = logging.StreamHandler(sys.stdout) - logger.addHandler(stream_handler) - logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks - return cls - - def test_bart_cnn_cli(self): - tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo" - output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo" - articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."] - _dump_articles(tmp, articles) - testargs = ["run_eval.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"] - with patch.object(sys, "argv", testargs): - run_generate() - self.assertTrue(Path(output_file_name).exists()) - os.remove(Path(output_file_name)) - - def test_t5_run_sum_cli(self): - args_d: dict = CHEAP_ARGS.copy() - - tmp_dir = make_test_data_dir() - output_dir = tempfile.mkdtemp(prefix="output_") - args_d.update( - data_dir=tmp_dir, - model_name_or_path=T5_TINY, - tokenizer_name=None, # T5_TINY, - train_batch_size=2, - eval_batch_size=2, - gpus=0, - output_dir=output_dir, - do_predict=True, - ) - assert "n_train" in args_d - args = argparse.Namespace(**args_d) - main(args) - - def test_bart_summarization_dataset(self): - tmp_dir = Path(tempfile.gettempdir()) - articles = [" Sam ate lunch today", "Sams lunch ingredients"] - summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"] - _dump_articles((tmp_dir / "train.source"), articles) - _dump_articles((tmp_dir / "train.target"), summaries) - tokenizer = BartTokenizer.from_pretrained("facebook/bart-large") - max_len_source = max(len(tokenizer.encode(a)) for a in articles) - max_len_target = max(len(tokenizer.encode(a)) for a in summaries) - trunc_target = 4 - train_dataset = SummarizationDataset( - tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target, - ) - dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn) - for batch in dataloader: - self.assertEqual(batch["attention_mask"].shape, batch["input_ids"].shape) - # show that articles were trimmed. 
- self.assertEqual(batch["input_ids"].shape[1], max_len_source) - self.assertGreater(20, batch["input_ids"].shape[1]) # trimmed significantly - - # show that targets were truncated - self.assertEqual(batch["decoder_input_ids"].shape[1], trunc_target) # Truncated - self.assertGreater(max_len_target, trunc_target) # Truncated - - -def list_to_text_file(lst, path): - dest = Path(path) - dest.open("w+").writelines(lst) diff --git a/examples/translation/t5/README.md b/examples/translation/t5/README.md deleted file mode 100644 index 7abcfb8a85e3..000000000000 --- a/examples/translation/t5/README.md +++ /dev/null @@ -1,51 +0,0 @@ -***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the English to German WMT dataset. Please note that the results in the paper were attained using a model fine-tuned on translation, so that results will be worse here by approx. 1.5 BLEU points*** - -### Intro - -This example shows how T5 (here the official [paper](https://arxiv.org/abs/1910.10683)) can be -evaluated on the WMT English-German dataset. - -### Get the WMT Data - -To be able to reproduce the authors' results on WMT English to German, you first need to download -the WMT14 en-de news datasets. -Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2014.en" and "newstest2014.de" under WMT'14 English-German data or download the dataset directly via: - -```bash -curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.en > newstest2014.en -curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2014.de > newstest2014.de -``` - -You should have 2737 sentences in each file. You can verify this by running: - -```bash -wc -l newstest2014.en # should give 2737 -``` - -### Usage - -Let's check the longest and shortest sentence in our file to find reasonable decoding hyperparameters: - -Get the longest and shortest sentence: - -```bash -awk '{print NF}' newstest2014.en | sort -n | head -1 # shortest sentence has 2 word -awk '{print NF}' newstest2014.en | sort -n | tail -1 # longest sentence has 91 words -``` - -We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0. -We decode with beam search `num_beams=4` as proposed in the paper. Also as is common in beam search we set `early_stopping=True` and `length_penalty=2.0`. - -To create translation for each in dataset and get a final BLEU score, run: -```bash -python evaluate_wmt.py newstest2014_de_translations.txt newsstest2014_en_de_bleu.txt -``` -the default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system. - -### Where is the code? -The core model is in `src/transformers/modeling_t5.py`. This directory only contains examples. - -### BLEU Scores - -The BLEU score is calculated using [sacrebleu](https://github.com/mjpost/sacreBLEU) by mjpost. 
-To get the BLEU score we used
diff --git a/examples/translation/t5/__init__.py b/examples/translation/t5/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/examples/translation/t5/evaluate_wmt.py b/examples/translation/t5/evaluate_wmt.py
deleted file mode 100644
index b2be05a95033..000000000000
--- a/examples/translation/t5/evaluate_wmt.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import argparse
-from pathlib import Path
-
-import torch
-from sacrebleu import corpus_bleu
-from tqdm import tqdm
-
-from transformers import T5ForConditionalGeneration, T5Tokenizer
-
-
-def chunks(lst, n):
-    """Yield successive n-sized chunks from lst."""
-    for i in range(0, len(lst), n):
-        yield lst[i : i + n]
-
-
-def generate_translations(lns, output_file_path, model_size, batch_size, device):
-    model = T5ForConditionalGeneration.from_pretrained(model_size)
-    model.to(device)
-
-    tokenizer = T5Tokenizer.from_pretrained(model_size)
-
-    # update config with summarization specific params
-    task_specific_params = model.config.task_specific_params
-    if task_specific_params is not None:
-        model.config.update(task_specific_params.get("translation_en_to_de", {}))
-
-    with Path(output_file_path).open("w") as output_file:
-        for batch in tqdm(list(chunks(lns, batch_size))):
-            batch = [model.config.prefix + text for text in batch]
-
-            dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)
-
-            input_ids = dct["input_ids"].to(device)
-            attention_mask = dct["attention_mask"].to(device)
-
-            translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
-            dec = [
-                tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations
-            ]
-
-            for hypothesis in dec:
-                output_file.write(hypothesis + "\n")
-
-
-def calculate_bleu_score(output_lns, refs_lns, score_path):
-    bleu = corpus_bleu(output_lns, [refs_lns])
-    result = "BLEU score: {}".format(bleu.score)
-    with Path(score_path).open("w") as score_file:
-        score_file.write(result)
-
-
-def run_generate():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "model_size",
-        type=str,
-        help="T5 model size, either 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
-        default="t5-base",
-    )
-    parser.add_argument(
-        "input_path", type=str, help="like wmt/newstest2014.en",
-    )
-    parser.add_argument(
-        "output_path", type=str, help="where to save translation",
-    )
-    parser.add_argument(
-        "reference_path", type=str, help="like wmt/newstest2014.de",
-    )
-    parser.add_argument(
-        "score_path", type=str, help="where to save the bleu score",
-    )
-    parser.add_argument(
-        "--batch_size", type=int, default=16, required=False, help="batch size: how many to summarize at a time",
-    )
-    parser.add_argument(
-        "--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
-    )
-
-    args = parser.parse_args()
-    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
-
-    dash_pattern = (" ##AT##-##AT## ", "-")
-
-    # Read input lines into python
-    with open(args.input_path, "r") as input_file:
-        input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in input_file.readlines()]
-
-    generate_translations(input_lns, args.output_path, args.model_size, args.batch_size, args.device)
-
-    # Read generated lines into python
-    with open(args.output_path, "r") as output_file:
-        output_lns = [x.strip() for x in output_file.readlines()]
-
-    # Read reference lines into python
-    with open(args.reference_path, "r") as reference_file:
-        refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in reference_file.readlines()]
-
-    calculate_bleu_score(output_lns, refs_lns, args.score_path)
-
-
-if __name__ == "__main__":
-    run_generate()
diff --git a/examples/translation/t5/test_t5_examples.py b/examples/translation/t5/test_t5_examples.py
deleted file mode 100644
index b33cba11c2da..000000000000
--- a/examples/translation/t5/test_t5_examples.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import logging
-import sys
-import tempfile
-import unittest
-from pathlib import Path
-from unittest.mock import patch
-
-from .evaluate_wmt import run_generate
-
-
-text = ["When Liana Barrientos was 23 years old, she got married in Westchester County."]
-translation = ["Als Liana Barrientos 23 Jahre alt war, heiratete sie in Westchester County."]
-
-output_file_name = "output_t5_trans.txt"
-score_file_name = "score_t5_trans.txt"
-
-logging.basicConfig(level=logging.DEBUG)
-
-logger = logging.getLogger()
-
-
-class TestT5Examples(unittest.TestCase):
-    def test_t5_cli(self):
-        stream_handler = logging.StreamHandler(sys.stdout)
-        logger.addHandler(stream_handler)
-
-        tmp_source = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.hypo"
-        with tmp_source.open("w") as f:
-            f.write("\n".join(text))
-
-        tmp_target = Path(tempfile.gettempdir()) / "utest_generations_t5_trans.target"
-        with tmp_target.open("w") as f:
-            f.write("\n".join(translation))
-
-        output_file_name = Path(tempfile.gettempdir()) / "utest_output_trans.hypo"
-        score_file_name = Path(tempfile.gettempdir()) / "utest_score.hypo"
-
-        testargs = [
-            "evaluate_wmt.py",
-            "patrickvonplaten/t5-tiny-random",
-            str(tmp_source),
-            str(output_file_name),
-            str(tmp_target),
-            str(score_file_name),
-        ]
-
-        with patch.object(sys, "argv", testargs):
-            run_generate()
-            self.assertTrue(Path(output_file_name).exists())
-            self.assertTrue(Path(score_file_name).exists())
diff --git a/setup.cfg b/setup.cfg
index 0b4c1af0714c..aa1dfcf111e5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -20,6 +20,7 @@ known_third_party =
     pandas
     PIL
     psutil
+    pytest
     pytorch_lightning
     rouge_score
     sacrebleu
diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py
index cf4f01c8d0f8..19538cef37bd 100644
--- a/src/transformers/tokenization_bart.py
+++ b/src/transformers/tokenization_bart.py
@@ -55,7 +55,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
     }


-_all_mbart_models = ["facebook/mbart-large-en-ro"]
+_all_mbart_models = ["facebook/mbart-large-en-ro", "sshleifer/mbart-large-cc25"]
 SPM_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/sentence.bpe.model"


@@ -105,6 +105,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         "vi_VN": 250024,
         "zh_CN": 250025,
     }
+    id_to_lang_code = {v: k for k, v in lang_code_to_id.items()}
     cur_lang_code = lang_code_to_id["en_XX"]

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
@@ -115,6 +116,16 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> Lis
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + token_ids_1 + special_tokens

+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        if index in self.id_to_lang_code:
+            return self.id_to_lang_code[index]
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def set_lang(self, lang: str) -> None:
+        """Set the current language code in order to call batch_encode_plus properly."""
+        self.cur_lang_code = self.lang_code_to_id[lang]
+
     def prepare_translation_batch(
         self,
         src_texts: List[str],