-
Notifications
You must be signed in to change notification settings - Fork 26.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ported model] FSMT (FairSeq MachineTranslation) (#6940)
* ready for PR * cleanup * correct FSMT_PRETRAINED_MODEL_ARCHIVE_LIST * fix * perfectionism * revert change from another PR * odd, already committed this one * non-interactive upload workaround * backup the failed experiment * store langs in config * workaround for localizing model path * doc clean up as in #6956 * style * back out debug mode * document: run_eval.py --num_beams 10 * remove unneeded constant * typo * re-use bart's Attention * re-use EncoderLayer, DecoderLayer from bart * refactor * send to cuda and fp16 * cleanup * revert (moved to another PR) * better error message * document run_eval --num_beams * solve the problem of tokenizer finding the right files when model is local * polish, remove hardcoded config * add a note that the file is autogenerated to avoid losing changes * prep for org change, remove unneeded code * switch to model4.pt, update scores * s/python/bash/ * missing init (but doesn't impact the finetuned model) * cleanup * major refactor (reuse-bart) * new model, new expected weights * cleanup * cleanup * full link * fix model type * merge porting notes * style * cleanup * have to create a DecoderConfig object to handle vocab_size properly * doc fix * add note (not a public class) * parametrize * - add bleu scores integration tests * skip test if sacrebleu is not installed * cache heavy models/tokenizers * some tweaks * remove tokens that aren't used * more purging * simplify code * switch to using decoder_start_token_id * add doc * Revert "major refactor (reuse-bart)" This reverts commit 226dad1. * decouple from bart * remove unused code #1 * remove unused code #2 * remove unused code #3 * update instructions * clean up * move bleu eval to examples * check import only once * move data+gen script into files * reuse via import * take less space * add prepare_seq2seq_batch (auto-tested) * cleanup * recode test to use json instead of yaml * ignore keys not needed * use the new -y in transformers-cli upload -y * [xlm tok] config dict: fix str into int to match definition (#7034) * [s2s] --eval_max_generate_length (#7018) * Fix CI with change of name of nlp (#7054) * nlp -> datasets * More nlp -> datasets * Woopsie * More nlp -> datasets * One last * extending to support allen_nlp wmt models - allow a specific checkpoint file to be passed - more arg settings - scripts for allen_nlp models * sync with changes * s/fsmt-wmt/wmt/ in model names * s/fsmt-wmt/wmt/ in model names (p2) * s/fsmt-wmt/wmt/ in model names (p3) * switch to a better checkpoint * typo * make non-optional args such - adjust tests where possible or skip when there is no other choice * consistency * style * adjust header * cards moved (model rename) * use best custom hparams * update info * remove old cards * cleanup * s/stas/facebook/ * update scores * s/allen_nlp/allenai/ * url maps aren't needed * typo * move all the doc / build /eval generators to their own scripts * cleanup * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * fix indent * duplicated line * style * use the correct add_start_docstrings * oops * resizing can't be done with the core approach, due to 2 dicts * check that the arg is a list * style * style Co-authored-by: Sam Shleifer <sshleifer@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
- Loading branch information
1 parent
492bb6a
commit 1eeb206
Showing
19 changed files
with
3,534 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
FSMT | ||
---------------------------------------------------- | ||
**DISCLAIMER:** If you see something strange, | ||
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign | ||
@stas00. | ||
|
||
Overview | ||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
FSMT (FairSeq MachineTranslation) models were introduced in "Facebook FAIR's WMT19 News Translation Task Submission" <this paper <https://arxiv.org/abs/1907.06616>__ by Nathan Ng, Kyra Yee, Alexei Baevski, Myle Ott, Michael Auli, Sergey Edunov. | ||
|
||
The abstract of the paper is the following: | ||
|
||
This paper describes Facebook FAIR's submission to the WMT19 shared news translation task. We participate in two language pairs and four language directions, English <-> German and English <-> Russian. Following our submission from last year, our baseline systems are large BPE-based transformer models trained with the Fairseq sequence modeling toolkit which rely on sampled back-translations. This year we experiment with different bitext data filtering schemes, as well as with adding filtered back-translated data. We also ensemble and fine-tune our models on domain-specific data, then decode using noisy channel model reranking. Our submissions are ranked first in all four directions of the human evaluation campaign. On En->De, our system significantly outperforms other systems as well as human translations. This system improves upon our WMT'18 submission by 4.5 BLEU points. | ||
|
||
The original code can be found here <https://github.com/pytorch/fairseq/tree/master/examples/wmt19>__. | ||
|
||
Implementation Notes | ||
~~~~~~~~~~~~~~~~~~~~ | ||
|
||
- FSMT uses source and target vocab pair, that aren't combined into one. It doesn't share embed tokens either. Its tokenizer is very similar to `XLMTokenizer` and the main model is derived from `BartModel`. | ||
|
||
|
||
FSMTForConditionalGeneration | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.FSMTForConditionalGeneration | ||
:members: forward | ||
|
||
|
||
FSMTConfig | ||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.FSMTConfig | ||
:members: | ||
|
||
|
||
FSMTTokenizer | ||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.FSMTTokenizer | ||
:members: | ||
|
||
|
||
FSMTModel | ||
~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.FSMTModel | ||
:members: forward |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#!/usr/bin/env python | ||
|
||
import io | ||
import json | ||
import subprocess | ||
|
||
|
||
pairs = [ | ||
["en", "ru"], | ||
["ru", "en"], | ||
["en", "de"], | ||
["de", "en"], | ||
] | ||
|
||
n_objs = 8 | ||
|
||
|
||
def get_all_data(pairs, n_objs): | ||
text = {} | ||
for src, tgt in pairs: | ||
pair = f"{src}-{tgt}" | ||
cmd = f"sacrebleu -t wmt19 -l {pair} --echo src".split() | ||
src_lines = subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode("utf-8").splitlines() | ||
cmd = f"sacrebleu -t wmt19 -l {pair} --echo ref".split() | ||
tgt_lines = subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode("utf-8").splitlines() | ||
text[pair] = {"src": src_lines[:n_objs], "tgt": tgt_lines[:n_objs]} | ||
return text | ||
|
||
|
||
text = get_all_data(pairs, n_objs) | ||
filename = "./fsmt_val_data.json" | ||
with io.open(filename, "w", encoding="utf-8") as f: | ||
bleu_data = json.dump(text, f, indent=2, ensure_ascii=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
{ | ||
"en-ru": { | ||
"src": [ | ||
"Welsh AMs worried about 'looking like muppets'", | ||
"There is consternation among some AMs at a suggestion their title should change to MWPs (Member of the Welsh Parliament).", | ||
"It has arisen because of plans to change the name of the assembly to the Welsh Parliament.", | ||
"AMs across the political spectrum are worried it could invite ridicule.", | ||
"One Labour AM said his group was concerned \"it rhymes with Twp and Pwp.\"", | ||
"For readers outside of Wales: In Welsh twp means daft and pwp means poo.", | ||
"A Plaid AM said the group as a whole was \"not happy\" and has suggested alternatives.", | ||
"A Welsh Conservative said his group was \"open minded\" about the name change, but noted it was a short verbal hop from MWP to Muppet." | ||
], | ||
"tgt": [ | ||
"Члены Национальной ассамблеи Уэльса обеспокоены, что \"выглядят как куклы\"", | ||
"Некоторые члены Национальной ассамблеи Уэльса в ужасе от предложения о том, что их наименование должно измениться на MPW (члены Парламента Уэльса).", | ||
"Этот вопрос был поднят в связи с планами по переименованию ассамблеи в Парламент Уэльса.", | ||
"Члены Национальной ассамблеи Уэльса всего политического спектра обеспокоены, что это может породить насмешки.", | ||
"Один из лейбористских членов Национальной ассамблеи Уэльса сказал, что его партия обеспокоена тем, что \"это рифмуется с Twp и Pwp\".", | ||
"Для читателей за предлами Уэльса: по-валлийски twp означает \"глупый\", а pwp означает \"какашка\".", | ||
"Член Национальной ассамблеи от Плайд сказал, что эта партия в целом \"не счастлива\" и предложил альтернативы.", | ||
"Представитель Консервативной партии Уэльса сказал, что его партия \"открыта\" к переименованию, но отметил, что между WMP и Muppet небольшая разница в произношении." | ||
] | ||
}, | ||
"ru-en": { | ||
"src": [ | ||
"Названо число готовящихся к отправке в Донбасс новобранцев из Украины", | ||
"Официальный представитель Народной милиции самопровозглашенной Луганской Народной Республики (ЛНР) Андрей Марочко заявил, что зимой 2018-2019 года Украина направит в Донбасс не менее 3 тыс. новобранцев.", | ||
"По его словам, таким образом Киев планирует \"хоть как-то доукомплектовать подразделения\".", | ||
"\"Нежелание граждан Украины проходить службу в рядах ВС Украины, массовые увольнения привели к низкой укомплектованности подразделений\", - рассказал Марочко, которого цитирует \"РИА Новости\".", | ||
"Он также не исключил, что реальные цифры призванных в армию украинцев могут быть увеличены в случае необходимости.", | ||
"В 2014-2017 годах Киев начал так называемую антитеррористическую операцию (АТО), которую позже сменили на операцию объединенных сил (ООС).", | ||
"Предполагалось, что эта мера приведет к усилению роли украинских силовиков в урегулировании ситуации.", | ||
"В конце августа 2018 года ситуация в Донбассе обострилась из-за убийства главы ДНР Александра Захарченко." | ||
], | ||
"tgt": [ | ||
"The number of new Ukrainian recruits ready to go to Donbass has become public", | ||
"Official representative of the peoples’ militia of the self-proclaimed Lugansk People’s Republic Andrey Marochko claimed that Ukrainian will send at least 3 thousand new recruits to Donbass in winter 2018-2019.", | ||
"This is how Kyiv tries “at least somehow to staff the units,” he said.", | ||
"“The unwillingness of Ukrainian citizens to serve in the Ukraine’s military forces, mass resignments lead to low understaffing,” said Marochko cited by RIA Novosti.", | ||
"Also, he doesn’t exclude that the real numbers of conscripts in the Ukrainian army can be raised is necessary.", | ||
"In 2014-2017, Kyiv started so-called antiterrorist operation, that ws later changed to the united forces operation.", | ||
"This measure was supposed to strengthen the role of the Ukrainian military in settling the situation.", | ||
"In the late August 2018, the situation in Donbass escalated as the DNR head Aleksandr Zakharchenko was killed." | ||
] | ||
}, | ||
"en-de": { | ||
"src": [ | ||
"Welsh AMs worried about 'looking like muppets'", | ||
"There is consternation among some AMs at a suggestion their title should change to MWPs (Member of the Welsh Parliament).", | ||
"It has arisen because of plans to change the name of the assembly to the Welsh Parliament.", | ||
"AMs across the political spectrum are worried it could invite ridicule.", | ||
"One Labour AM said his group was concerned \"it rhymes with Twp and Pwp.\"", | ||
"For readers outside of Wales: In Welsh twp means daft and pwp means poo.", | ||
"A Plaid AM said the group as a whole was \"not happy\" and has suggested alternatives.", | ||
"A Welsh Conservative said his group was \"open minded\" about the name change, but noted it was a short verbal hop from MWP to Muppet." | ||
], | ||
"tgt": [ | ||
"Walisische Ageordnete sorgen sich \"wie Dödel auszusehen\"", | ||
"Es herrscht Bestürzung unter einigen Mitgliedern der Versammlung über einen Vorschlag, der ihren Titel zu MWPs (Mitglied der walisischen Parlament) ändern soll.", | ||
"Der Grund dafür waren Pläne, den Namen der Nationalversammlung in Walisisches Parlament zu ändern.", | ||
"Mitglieder aller Parteien der Nationalversammlung haben Bedenken, dass sie sich dadurch Spott aussetzen könnten.", | ||
"Ein Labour-Abgeordneter sagte, dass seine Gruppe \"sich mit Twp und Pwp reimt\".", | ||
"Hinweis für den Leser: „twp“ im Walisischen bedeutet „bescheuert“ und „pwp“ bedeutet „Kacke“.", | ||
"Ein Versammlungsmitglied von Plaid Cymru sagte, die Gruppe als Ganzes sei \"nicht glücklich\" und hat Alternativen vorgeschlagen.", | ||
"Ein walisischer Konservativer sagte, seine Gruppe wäre „offen“ für eine Namensänderung, wies aber darauf hin, dass es von „MWP“ (Mitglied des Walisischen Parlaments) nur ein kurzer verbaler Sprung zu „Muppet“ ist." | ||
] | ||
}, | ||
"de-en": { | ||
"src": [ | ||
"Schöne Münchnerin 2018: Schöne Münchnerin 2018 in Hvar: Neun Dates", | ||
"Von az, aktualisiert am 04.05.2018 um 11:11", | ||
"Ja, sie will...", | ||
"\"Schöne Münchnerin\" 2018 werden!", | ||
"Am Nachmittag wartet erneut eine Überraschung auf unsere Kandidatinnen: sie werden das romantische Candlelight-Shooting vor der MY SOLARIS nicht alleine bestreiten, sondern an der Seite von Male-Model Fabian!", | ||
"Hvar - Flirten, kokettieren, verführen - keine einfachen Aufgaben für unsere Mädchen.", | ||
"Insbesondere dann, wenn in Deutschland ein Freund wartet.", | ||
"Dennoch liefern die neun \"Schöne Münchnerin\"-Kandidatinnen beim Shooting mit People-Fotograf Tuan ab und trotzen Wind, Gischt und Regen wie echte Profis." | ||
], | ||
"tgt": [ | ||
"The Beauty of Munich 2018: the Beauty of Munich 2018 in Hvar: Nine dates", | ||
"From A-Z, updated on 04/05/2018 at 11:11", | ||
"Yes, she wants to...", | ||
"to become \"The Beauty of Munich\" in 2018!", | ||
"In the afternoon there is another surprise waiting for our contestants: they will be competing for the romantic candlelight photo shoot at MY SOLARIS not alone, but together with a male-model Fabian!", | ||
"Hvar with its flirting, coquetting, and seduction is not an easy task for our girls.", | ||
"Especially when there is a boyfriend waiting in Germany.", | ||
"Despite dealing with wind, sprays and rain, the nine contestants of \"The Beauty of Munich\" behaved like real professionals at the photo shoot with People-photographer Tuan." | ||
] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 Huggingface | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import io | ||
import unittest | ||
|
||
|
||
try: | ||
from .utils import calculate_bleu | ||
except ImportError: | ||
from utils import calculate_bleu | ||
|
||
import json | ||
|
||
from parameterized import parameterized | ||
from transformers import FSMTForConditionalGeneration, FSMTTokenizer | ||
from transformers.testing_utils import get_tests_dir, require_torch, slow, torch_device | ||
|
||
|
||
filename = get_tests_dir() + "/test_data/fsmt/fsmt_val_data.json" | ||
with io.open(filename, "r", encoding="utf-8") as f: | ||
bleu_data = json.load(f) | ||
|
||
|
||
@require_torch | ||
class ModelEvalTester(unittest.TestCase): | ||
def get_tokenizer(self, mname): | ||
return FSMTTokenizer.from_pretrained(mname) | ||
|
||
def get_model(self, mname): | ||
model = FSMTForConditionalGeneration.from_pretrained(mname).to(torch_device) | ||
if torch_device == "cuda": | ||
model.half() | ||
return model | ||
|
||
@parameterized.expand( | ||
[ | ||
["en-ru", 26.0], | ||
["ru-en", 22.0], | ||
["en-de", 22.0], | ||
["de-en", 29.0], | ||
] | ||
) | ||
@slow | ||
def test_bleu_scores(self, pair, min_bleu_score): | ||
# note: this test is not testing the best performance since it only evals a small batch | ||
# but it should be enough to detect a regression in the output quality | ||
mname = f"facebook/wmt19-{pair}" | ||
tokenizer = self.get_tokenizer(mname) | ||
model = self.get_model(mname) | ||
|
||
src_sentences = bleu_data[pair]["src"] | ||
tgt_sentences = bleu_data[pair]["tgt"] | ||
|
||
batch = tokenizer(src_sentences, return_tensors="pt", truncation=True, padding="longest").to(torch_device) | ||
outputs = model.generate( | ||
input_ids=batch.input_ids, | ||
num_beams=8, | ||
) | ||
decoded_sentences = tokenizer.batch_decode( | ||
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
) | ||
scores = calculate_bleu(decoded_sentences, tgt_sentences) | ||
print(scores) | ||
self.assertGreaterEqual(scores["bleu"], min_bleu_score) |
Oops, something went wrong.