Seq2seq trainer (huggingface#9241)
* Add label smoothing in Trainer

* Add options for scheduler and Adafactor in Trainer

* Put Seq2SeqTrainer in the main lib

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments and adapt scripts

* Documentation

* Move test not using script to tests folder

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
3 people authored Dec 22, 2020
1 parent 1fc7119 commit 490b39e
Showing 20 changed files with 655 additions and 166 deletions.
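For orientation, a short sketch of the `Trainer` options the commit message refers to (label smoothing, scheduler choice, Adafactor). The argument names below are assumptions based on the released library API (`label_smoothing_factor`, `lr_scheduler_type`, `adafactor`), not taken verbatim from this diff:

```python
# Hedged sketch: TrainingArguments fields corresponding to the features listed
# in the commit message (label smoothing, scheduler selection, Adafactor).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output",
    label_smoothing_factor=0.1,   # label smoothing in Trainer
    lr_scheduler_type="cosine",   # scheduler selection (see SchedulerType / get_scheduler below)
    adafactor=True,               # use Adafactor instead of AdamW
    warmup_steps=500,
)
```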
4 changes: 4 additions & 0 deletions docs/source/main_classes/optimizer_schedules.rst
@@ -43,6 +43,10 @@ Schedules
Learning Rate Schedules (Pytorch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.SchedulerType

.. autofunction:: transformers.get_scheduler

.. autofunction:: transformers.get_constant_schedule


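For reference, a minimal sketch of the newly documented `get_scheduler` entry point. The toy model and the PyTorch `AdamW` optimizer are purely illustrative:

```python
# Minimal sketch of transformers.get_scheduler with a linear warmup/decay schedule.
import torch
from transformers import get_scheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

num_training_steps = 1000
lr_scheduler = get_scheduler(
    "linear",                      # any member of transformers.SchedulerType also works
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)

for step in range(num_training_steps):
    # ... forward/backward pass would go here ...
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
```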
14 changes: 14 additions & 0 deletions docs/source/main_classes/trainer.rst
@@ -63,6 +63,13 @@ Trainer
:members:


Seq2SeqTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Seq2SeqTrainer
:members: evaluate, predict


TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -77,6 +84,13 @@ TrainingArguments
:members:


Seq2SeqTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Seq2SeqTrainingArguments
:members:


TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

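A minimal sketch of the `Seq2SeqTrainer` / `Seq2SeqTrainingArguments` pair documented above. Dataset construction is elided; `train_dataset` and `eval_dataset` are placeholders for already-tokenized datasets:

```python
# Hedged sketch: wiring up Seq2SeqTrainer with generation-aware evaluation.
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    predict_with_generate=True,   # needed to compute generation-based metrics
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # placeholder: tokenized training set
    eval_dataset=eval_dataset,    # placeholder: tokenized validation set
    tokenizer=tokenizer,
)

trainer.train()
# evaluate/predict accept generation arguments such as max_length and num_beams.
metrics = trainer.evaluate(max_length=128, num_beams=4)
```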
44 changes: 25 additions & 19 deletions examples/seq2seq/finetune_trainer.py
@@ -20,9 +20,16 @@
from typing import Optional

import transformers
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
HfArgumentParser,
MBartTokenizer,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.trainer_utils import EvaluationStrategy, is_main_process
from transformers.training_args import ParallelMode
from utils import (
@@ -86,28 +93,21 @@ class DataTrainingArguments:
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
max_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
eval_max_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
" This argument is also used to override the ``max_length`` param of ``model.generate``, which is used during ``evaluate`` and ``predict``"
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
@@ -233,7 +233,7 @@ def main():
type_path="train",
data_dir=data_args.data_dir,
n_obs=data_args.n_train,
max_target_length=data_args.max_target_length,
max_target_length=data_args.max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -246,7 +246,7 @@ def main():
type_path="val",
data_dir=data_args.data_dir,
n_obs=data_args.n_val,
max_target_length=data_args.val_max_target_length,
max_target_length=data_args.eval_max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -259,7 +259,7 @@ def main():
type_path="test",
data_dir=data_args.data_dir,
n_obs=data_args.n_test,
max_target_length=data_args.test_max_target_length,
max_target_length=data_args.eval_max_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
@@ -273,13 +273,12 @@ def main():
)
trainer = Seq2SeqTrainer(
model=model,
config=config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=compute_metrics_fn,
data_args=data_args,
tokenizer=tokenizer,
)

all_metrics = {}
@@ -310,7 +309,9 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")

metrics = trainer.evaluate(metric_key_prefix="val")
metrics = trainer.evaluate(
metric_key_prefix="val", max_length=data_args.eval_max_length, num_beams=data_args.eval_beams
)
metrics["val_n_objs"] = data_args.n_val
metrics["val_loss"] = round(metrics["val_loss"], 4)

@@ -322,7 +323,12 @@ def main():
if training_args.do_predict:
logger.info("*** Predict ***")

test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
test_output = trainer.predict(
test_dataset=test_dataset,
metric_key_prefix="test",
max_length=data_args.eval_max_length,
num_beams=data_args.eval_beams,
)
metrics = test_output.metrics
metrics["test_n_objs"] = data_args.n_test

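A trimmed sketch of how the renamed data arguments are parsed. `DataTrainingArguments` below is a cut-down, hypothetical stand-in for the full dataclass in `finetune_trainer.py`:

```python
# Hedged sketch: parsing the renamed arguments (max_length, eval_max_length)
# alongside Seq2SeqTrainingArguments via HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, Seq2SeqTrainingArguments


@dataclass
class DataTrainingArguments:
    data_dir: str = field(metadata={"help": "Path to the dataset directory."})
    max_source_length: Optional[int] = field(default=1024)
    max_length: Optional[int] = field(default=128)       # was max_target_length
    eval_max_length: Optional[int] = field(default=142)  # was val/test_max_target_length


parser = HfArgumentParser((DataTrainingArguments, Seq2SeqTrainingArguments))
data_args, training_args = parser.parse_args_into_dataclasses()
print(data_args.max_length, data_args.eval_max_length)
```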
123 changes: 4 additions & 119 deletions examples/seq2seq/test_finetune_trainer.py
@@ -17,8 +17,7 @@
import unittest
from unittest.mock import patch

from transformers import BertTokenizer, EncoderDecoderModel
from transformers.file_utils import is_apex_available, is_datasets_available
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.testing_utils import (
TestCasePlus,
@@ -31,8 +30,7 @@
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import Seq2SeqTrainingArguments, main
from .seq2seq_trainer import Seq2SeqTrainer
from .finetune_trainer import main


set_seed(42)
@@ -120,119 +118,6 @@ def test_finetune_trainer_slow(self):
assert "test_generations.txt" in contents
assert "test_results.json" in contents

@slow
def test_finetune_bert2bert(self):
if not is_datasets_available():
return

import datasets

bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.max_length = 128

train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")

train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16))

rouge = datasets.load_metric("rouge")

batch_size = 4

def _map_to_encoder_decoder_inputs(batch):
# Tokenizer will automatically set [BOS] <text> [EOS]
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
batch["input_ids"] = inputs.input_ids
batch["attention_mask"] = inputs.attention_mask

batch["decoder_input_ids"] = outputs.input_ids
batch["labels"] = outputs.input_ids.copy()
batch["labels"] = [
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
]
batch["decoder_attention_mask"] = outputs.attention_mask

assert all([len(x) == 512 for x in inputs.input_ids])
assert all([len(x) == 128 for x in outputs.input_ids])

return batch

def _compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions

# all unnecessary tokens are removed
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
"rouge2"
].mid

return {
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}

# map train dataset
train_dataset = train_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
train_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# same for validation dataset
val_dataset = val_dataset.map(
_map_to_encoder_decoder_inputs,
batched=True,
batch_size=batch_size,
remove_columns=["article", "highlights"],
)
val_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

output_dir = self.get_auto_remove_tmp_dir()

training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
predict_with_generate=True,
evaluation_strategy="steps",
do_train=True,
do_eval=True,
warmup_steps=0,
eval_steps=2,
logging_steps=2,
)

# instantiate trainer
trainer = Seq2SeqTrainer(
model=bert2bert,
args=training_args,
compute_metrics=_compute_metrics,
train_dataset=train_dataset,
eval_dataset=val_dataset,
)

# start training
trainer.train()

def run_trainer(
self,
eval_steps: int,
@@ -252,8 +137,8 @@ def run_trainer(
--n_train 8
--n_val 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
--max_length {max_len}
--eval_max_length {max_len}
--do_train
--do_eval
--do_predict
2 changes: 1 addition & 1 deletion examples/seq2seq/train_distil_marian_enro.sh
@@ -29,7 +29,7 @@ python finetune_trainer.py \
--freeze_encoder --freeze_embeds \
--num_train_epochs=6 \
--save_steps 3000 --eval_steps 3000 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --logging_first_step \
2 changes: 1 addition & 1 deletion examples/seq2seq/train_distil_marian_enro_tpu.sh
@@ -30,7 +30,7 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
--num_train_epochs=6 \
--save_steps 500 --eval_steps 500 \
--logging_first_step --logging_steps 200 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--max_source_length $MAX_LEN --max_length $MAX_LEN --eval_max_length $MAX_LEN \
--do_train --do_eval \
--evaluation_strategy steps \
--prediction_loss_only \
2 changes: 1 addition & 1 deletion examples/seq2seq/train_distilbart_cnn.sh
@@ -32,7 +32,7 @@ python finetune_trainer.py \
--num_train_epochs=2 \
--save_steps 3000 --eval_steps 3000 \
--logging_first_step \
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--max_length 56 --eval_max_length $MAX_TGT_LEN \
--do_train --do_eval --do_predict \
--evaluation_strategy steps \
--predict_with_generate --sortish_sampler \
3 changes: 1 addition & 2 deletions examples/seq2seq/train_mbart_cc25_enro.sh
@@ -24,8 +24,7 @@ python finetune_trainer.py \
--src_lang en_XX --tgt_lang ro_RO \
--freeze_embeds \
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
--max_source_length 128 --max_target_length 128 \
--val_max_target_length 128 --test_max_target_length 128 \
--max_source_length 128 --max_length 128 --eval_max_length 128 \
--sortish_sampler \
--num_train_epochs 6 \
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
2 changes: 1 addition & 1 deletion examples/seq2seq/utils.py
@@ -330,7 +330,7 @@ def _encode(self, batch) -> Dict[str, torch.Tensor]:
[x["src_texts"] for x in batch],
tgt_texts=[x["tgt_texts"] for x in batch],
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
max_target_length=self.data_args.max_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
**self.dataset_kwargs,
5 changes: 4 additions & 1 deletion src/transformers/__init__.py
@@ -287,8 +287,9 @@
TrainerControl,
TrainerState,
)
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
from .utils import logging

@@ -682,11 +683,13 @@
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)

# Trainer
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer
else:
from .utils.dummy_pt_objects import *

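A quick import check reflecting the `__init__.py` additions above (`Seq2SeqTrainer`, `Seq2SeqTrainingArguments`, `SchedulerType`, and `get_scheduler` are now exposed at the top level):

```python
# Verify the new top-level exports added in src/transformers/__init__.py.
from transformers import (
    SchedulerType,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    get_scheduler,
)

# SchedulerType enumerates the supported schedules, e.g. "linear", "cosine", ...
print([s.value for s in SchedulerType])
```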