Skip to content

Commit

Permalink
Add wmt translation example (#3428)
Browse files Browse the repository at this point in the history
* add translation example

* make style

* adapt docstring

* add gpu device as input for example

* small renaming

* better README
  • Loading branch information
patrickvonplaten authored Mar 26, 2020
1 parent b4fb94f commit 5ad2ea0
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 1 deletion.
3 changes: 2 additions & 1 deletion examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ tensorboard
scikit-learn
seqeval
psutil
sacrebleu
rouge-score
tensorflow_datasets
tensorflow_datasets
51 changes: 51 additions & 0 deletions examples/translation/t5/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
***This script evaluates the multitask pre-trained checkpoint for ``t5-base`` (see paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the English to German WMT dataset. Please note that the results in the paper were attained using a model fine-tuned on translation, so that results will be worse here by approx. 1.5 BLEU points***

### Intro

This example shows how T5 (here the official [paper](https://arxiv.org/abs/1910.10683)) can be
evaluated on the WMT English-German dataset.

### Get the WMT Data

To be able to reproduce the authors' results on WMT English to German, you first need to download
the WMT14 en-de news datasets.
Go on Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2013.en" and "newstest2013.de" under WMT'14 English-German data or download the dataset directly via:

```bash
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en > newstest2013.en
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de > newstest2013.de
```

You should have 3000 sentence in each file. You can verify this by running:

```bash
wc -l newstest2013.en # should give 3000
```

### Usage

Let's check the longest and shortest sentence in our file to find reasonable decoding hyperparameters:

Get the longest and shortest sentence:

```bash
awk '{print NF}' newstest2013.en | sort -n | head -1 # shortest sentence has 1 word
awk '{print NF}' newstest2013.en | sort -n | tail -1 # longest sentence has 106 words
```

We will set our `max_length` to ~3 times the longest sentence and leave `min_length` to its default value of 0.
We decode with beam search `num_beams=4` as proposed in the paper. Also as is common in beam search we set `early_stopping=True` and `length_penalty=2.0`.

To create translation for each in dataset and get a final BLEU score, run:
```bash
python evaluate_wmt.py <path_to_newstest2013.en> newstest2013_de_translations.txt <path_to_newstest2013.de> newsstest2013_en_de_bleu.txt
```
the default batch size, 16, fits in 16GB GPU memory, but may need to be adjusted to fit your system.

### Where is the code?
The core model is in `src/transformers/modeling_t5.py`. This directory only contains examples.

### BLEU Scores

The BLEU score is calculated using [sacrebleu](https://github.com/mjpost/sacreBLEU) by mjpost.
To get the BLEU score we used
Empty file.
90 changes: 90 additions & 0 deletions examples/translation/t5/evaluate_wmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import argparse
from pathlib import Path

import torch
from tqdm import tqdm

from sacrebleu import corpus_bleu
from transformers import T5ForConditionalGeneration, T5Tokenizer


def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]


def generate_translations(lns, output_file_path, batch_size, device):
output_file = Path(output_file_path).open("w")

model = T5ForConditionalGeneration.from_pretrained("t5-base")
model.to(device)

tokenizer = T5Tokenizer.from_pretrained("t5-base")

# update config with summarization specific params
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
model.config.update(task_specific_params.get("translation_en_to_de", {}))

for batch in tqdm(list(chunks(lns, batch_size))):
batch = [model.config.prefix + text for text in batch]

dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)

input_ids = dct["input_ids"].to(device)
attention_mask = dct["attention_mask"].to(device)

translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations]

for hypothesis in dec:
output_file.write(hypothesis + "\n")
output_file.flush()


def calculate_bleu_score(output_lns, refs_lns, score_path):
bleu = corpus_bleu(output_lns, [refs_lns])
result = "BLEU score: {}".format(bleu.score)
score_file = Path(score_path).open("w")
score_file.write(result)


def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument(
"input_path", type=str, help="like wmt/newstest2013.en",
)
parser.add_argument(
"output_path", type=str, help="where to save translation",
)
parser.add_argument(
"reference_path", type=str, help="like wmt/newstest2013.de",
)
parser.add_argument(
"score_path", type=str, help="where to save the bleu score",
)
parser.add_argument(
"--batch_size", type=int, default=16, required=False, help="batch size: how many to summarize at a time",
)
parser.add_argument(
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
)

args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

dash_pattern = (" ##AT##-##AT## ", "-")

input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()]

generate_translations(input_lns, args.output_path, args.batch_size, args.device)

output_lns = [x.strip() for x in open(args.output_path).readlines()]
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()]

calculate_bleu_score(output_lns, refs_lns, args.score_path)


if __name__ == "__main__":
run_generate()
28 changes: 28 additions & 0 deletions examples/translation/t5/test_t5_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_wmt import run_generate


text = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


class TestT5Examples(unittest.TestCase):
def test_t5_cli(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo"
with tmp.open("w") as f:
f.write("\n".join(text))
testargs = ["evaluate_cnn.py", str(tmp), "output.txt", str(tmp), "score.txt"]
with patch.object(sys, "argv", testargs):
run_generate()
self.assertTrue(Path("output.txt").exists())

0 comments on commit 5ad2ea0

Please sign in to comment.