* add translation example
* make style
* adapt docstring
* add gpu device as input for example
* small renaming
* better README

1 parent b4fb94f, commit 5ad2ea0

Showing 5 changed files with 171 additions and 1 deletion.

**`requirements.txt`** (adds `sacrebleu`)

```diff
@@ -3,5 +3,6 @@ tensorboard
 scikit-learn
 seqeval
 psutil
+sacrebleu
 rouge-score
-tensorflow_datasets
+tensorflow_datasets
```

**`README.md`** (new file, 51 lines)

***This script evaluates the multitask pre-trained checkpoint for `t5-base` (see the paper [here](https://arxiv.org/pdf/1910.10683.pdf)) on the English to German WMT dataset. Please note that the results in the paper were attained with a model fine-tuned on translation, so the results here will be lower by approximately 1.5 BLEU points.***

### Intro

This example shows how T5 (see the official [paper](https://arxiv.org/abs/1910.10683)) can be evaluated on the WMT English-German dataset.
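
Before downloading any data, it can help to peek at the translation settings that ship with the multitask checkpoint. This is only a small sketch; it assumes, as the evaluation script below does, that the `t5-base` config bundles `task_specific_params`:

```python
from transformers import T5ForConditionalGeneration

# Load the multitask pre-trained checkpoint evaluated in this example.
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# The config may bundle task-specific settings, e.g. the task prefix and
# decoding defaults used for English-to-German translation.
task_specific_params = model.config.task_specific_params or {}
print(task_specific_params.get("translation_en_to_de", {}))
```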

### Get the WMT Data

To be able to reproduce the authors' results on WMT English to German, you first need to download the WMT14 en-de news test sets.
Go to Stanford's official NLP [website](https://nlp.stanford.edu/projects/nmt/) and find "newstest2013.en" and "newstest2013.de" under WMT'14 English-German data, or download the dataset directly via:

```bash
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.en > newstest2013.en
curl https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2013.de > newstest2013.de
```

You should have 3000 sentences in each file. You can verify this by running:

```bash
wc -l newstest2013.en # should give 3000
```

### Usage

Let's check the longest and shortest sentences in our file to pick reasonable decoding hyperparameters:

```bash
awk '{print NF}' newstest2013.en | sort -n | head -1 # shortest sentence has 1 word
awk '{print NF}' newstest2013.en | sort -n | tail -1 # longest sentence has 106 words
```

We will set `max_length` to roughly 3 times the length of the longest sentence and leave `min_length` at its default value of 0.
We decode with beam search using `num_beams=4`, as proposed in the paper. As is common with beam search, we also set `early_stopping=True` and `length_penalty=2.0`.
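
The evaluation script below reads these generation settings from the model config (via `task_specific_params`) rather than passing them explicitly. Purely as an illustration, a standalone call with the same settings might look like this; the example sentence and the concrete `max_length=300` are our own choices:

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
tokenizer = T5Tokenizer.from_pretrained("t5-base")

# t5-base expects a task prefix for English-to-German translation.
text = "translate English to German: The house is wonderful."
input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

# Decoding settings discussed above; max_length=300 (~3x the longest sentence)
# and min_length=0 are illustrative choices, the rest follow the paper.
generated = model.generate(
    input_ids,
    num_beams=4,
    early_stopping=True,
    length_penalty=2.0,
    max_length=300,
    min_length=0,
)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```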

To create a translation for each sentence in the dataset and get a final BLEU score, run:

```bash
python evaluate_wmt.py <path_to_newstest2013.en> newstest2013_de_translations.txt <path_to_newstest2013.de> newstest2013_en_de_bleu.txt
```

The default batch size of 16 fits in 16 GB of GPU memory but may need to be adjusted for your system.

### Where is the code?

The core model is in `src/transformers/modeling_t5.py`. This directory only contains examples.

### BLEU Scores

The BLEU score is calculated using [sacrebleu](https://github.com/mjpost/sacreBLEU) by mjpost.
To get the BLEU score, we use sacrebleu's `corpus_bleu` function, as done in `evaluate_wmt.py`.
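
As a small self-contained illustration of that scoring step (the file names are the placeholder outputs from the command above):

```python
from sacrebleu import corpus_bleu

# Hypotheses written by evaluate_wmt.py and the German reference sentences
# (file names follow the example command above).
hypotheses = [line.strip() for line in open("newstest2013_de_translations.txt")]
references = [line.strip() for line in open("newstest2013.de")]

# corpus_bleu expects a list of reference streams, hence the extra nesting.
bleu = corpus_bleu(hypotheses, [references])
print("BLEU score: {}".format(bleu.score))
```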

**`__init__.py`** (empty file)

**`evaluate_wmt.py`** (new file, 90 lines)

```python
import argparse
from pathlib import Path

import torch
from tqdm import tqdm

from sacrebleu import corpus_bleu
from transformers import T5ForConditionalGeneration, T5Tokenizer


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def generate_translations(lns, output_file_path, batch_size, device):
    output_file = Path(output_file_path).open("w")

    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    model.to(device)

    tokenizer = T5Tokenizer.from_pretrained("t5-base")

    # update config with translation specific params (task prefix and decoding defaults)
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get("translation_en_to_de", {}))

    for batch in tqdm(list(chunks(lns, batch_size))):
        # prepend the task prefix, e.g. "translate English to German: "
        batch = [model.config.prefix + text for text in batch]

        dct = tokenizer.batch_encode_plus(batch, max_length=512, return_tensors="pt", pad_to_max_length=True)

        input_ids = dct["input_ids"].to(device)
        attention_mask = dct["attention_mask"].to(device)

        translations = model.generate(input_ids=input_ids, attention_mask=attention_mask)
        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in translations]

        for hypothesis in dec:
            output_file.write(hypothesis + "\n")
            output_file.flush()


def calculate_bleu_score(output_lns, refs_lns, score_path):
    bleu = corpus_bleu(output_lns, [refs_lns])
    result = "BLEU score: {}".format(bleu.score)
    score_file = Path(score_path).open("w")
    score_file.write(result)


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_path", type=str, help="like wmt/newstest2013.en",
    )
    parser.add_argument(
        "output_path", type=str, help="where to save translation",
    )
    parser.add_argument(
        "reference_path", type=str, help="like wmt/newstest2013.de",
    )
    parser.add_argument(
        "score_path", type=str, help="where to save the bleu score",
    )
    parser.add_argument(
        "--batch_size", type=int, default=16, required=False, help="batch size: how many sentences to translate at a time",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether to force the execution on CPU.",
    )

    args = parser.parse_args()
    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    # WMT marks hyphenated words with " ##AT##-##AT## "; map them back to plain hyphens.
    dash_pattern = (" ##AT##-##AT## ", "-")

    input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.input_path).readlines()]

    generate_translations(input_lns, args.output_path, args.batch_size, args.device)

    output_lns = [x.strip() for x in open(args.output_path).readlines()]
    refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(args.reference_path).readlines()]

    calculate_bleu_score(output_lns, refs_lns, args.score_path)


if __name__ == "__main__":
    run_generate()
```
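
Since the script exposes plain functions, it can also be driven from Python instead of the command line. The following is only a sketch, assuming it runs from the example directory with placeholder file paths, and mirrors what `run_generate` does:

```python
import torch

from evaluate_wmt import calculate_bleu_score, generate_translations

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder paths: adjust to wherever the newstest2013 files live.
source_path, reference_path = "newstest2013.en", "newstest2013.de"
hypo_path, score_path = "newstest2013_de_translations.txt", "newstest2013_en_de_bleu.txt"

# Same light normalisation of WMT hyphen markers as run_generate().
dash_pattern = (" ##AT##-##AT## ", "-")
input_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(source_path).readlines()]

generate_translations(input_lns, hypo_path, batch_size=16, device=device)

output_lns = [x.strip() for x in open(hypo_path).readlines()]
refs_lns = [x.strip().replace(dash_pattern[0], dash_pattern[1]) for x in open(reference_path).readlines()]
calculate_bleu_score(output_lns, refs_lns, score_path)
```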

**Test file** (new file, 28 lines)

```python
import logging
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from .evaluate_wmt import run_generate


text = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


class TestT5Examples(unittest.TestCase):
    def test_t5_cli(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)
        tmp = Path(tempfile.gettempdir()) / "utest_generations.hypo"
        with tmp.open("w") as f:
            f.write("\n".join(text))
        # Smoke test: run the CLI on a single sentence, using the same file as
        # source and (dummy) reference, and check that a translation file is written.
        testargs = ["evaluate_wmt.py", str(tmp), "output.txt", str(tmp), "score.txt"]
        with patch.object(sys, "argv", testargs):
            run_generate()
            self.assertTrue(Path("output.txt").exists())
```