open source mbart (facebookresearch#1033)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typo fixes or doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!
Pull Request resolved: fairinternal/fairseq-py#1033

Differential Revision: D20122520

Pulled By: yinhanliu

fbshipit-source-id: e2fd93e2fa9b7a8e276acc4316a176ba3ceae4ed
Yinhan Liu authored and facebook-github-bot committed Feb 27, 2020
1 parent f8b795f commit 5e79322
Showing 10 changed files with 461 additions and 10 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -16,6 +16,7 @@ modeling and other text generation tasks.

### What's New:

- February 2020: [mBART model and code released](examples/mbart/README.md)
- February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german)
- December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
- November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
@@ -50,6 +51,7 @@ Fairseq provides reference implementations of various sequence-to-sequence model
- [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
- [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
- [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
- [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
- **Non-autoregressive Transformers**
- Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
- Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
98 changes: 98 additions & 0 deletions examples/mbart/README.md
@@ -0,0 +1,98 @@
# mBART: Multilingual Denoising Pre-training for Neural Machine Translation

[arXiv:2001.08210](https://arxiv.org/abs/2001.08210)

## Introduction

mBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. It is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages; previous approaches focused only on the encoder, only on the decoder, or on reconstructing parts of the text.
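
Concretely, pre-training maximizes the likelihood of reconstructing each document `X` from its noised version `g(X)` (span masking plus sentence permutation), summed over the monolingual corpora `D_i` of the 25 languages. A sketch of the objective, paraphrased from the paper:

```latex
% Denoising pre-training objective (paraphrased from the mBART paper);
% g(.) applies span masking and sentence permutation to the input document.
\mathcal{L}_{\theta} = \sum_{D_i \in \mathcal{D}} \sum_{X \in D_i} \log P\left(X \mid g(X); \theta\right)
```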

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`mbart.CC25` | mBART model with 12 encoder and 12 decoder layers, trained on monolingual corpora in 25 languages | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz)
`mbart.cc25.ft.enro` | mBART CC25 model fine-tuned on EN-RO translation | 610M | [mbart.cc25.ft.enro.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz)

## Results

**[WMT16 EN-RO](https://www.statmt.org/wmt16/translation-task.html)**

_(test set, no additional data used)_

Model | en-ro | ro-en
---|---|---
`Random` | 34.3 | 34.0
`mbart.cc25` | 37.7 | 37.8
`mbart.enro.bilingual` | 38.5 | 38.5

## BPE data

Download the pre-trained model:

```bash
# download model
wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz
tar -xzvf mbart.CC25.tar.gz
```

Install SentencePiece (SPM) from [here](https://github.com/google/sentencepiece), then apply the BPE model to each split:
```bash
SPM=/path/to/sentencepiece/build/src/spm_encode
MODEL=sentence.bpe.model
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${SRC} > ${DATA}/${TRAIN}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${TGT} > ${DATA}/${TRAIN}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${SRC} > ${DATA}/${VALID}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${TGT} > ${DATA}/${VALID}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${SRC} > ${DATA}/${TEST}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${TGT} > ${DATA}/${TEST}.spm.${TGT} &
```
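
The encoding commands above assume shell variables for the data location, split names, and language-code suffixes; since the six `spm_encode` jobs are backgrounded with `&`, run `wait` afterwards to block until they finish. A minimal sketch of one possible setup (all values are illustrative, not prescribed by the release):

```bash
# Illustrative values only -- adjust to your own data layout.
DATA=/path/to/wmt16_en_ro   # directory with the raw train/valid/test files
TRAIN=train
VALID=valid
TEST=test
SRC=en_XX                   # files are expected to be named e.g. train.en_XX
TGT=ro_RO
MODEL=sentence.bpe.model    # SPM model shipped inside the extracted mbart.CC25 archive
```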

## Preprocess data

```bash
DICT=dict.txt
python preprocess.py \
--source-lang ${SRC} \
--target-lang ${TGT} \
--trainpref ${DATA}/${TRAIN}.spm \
--validpref ${DATA}/${VALID}.spm \
--testpref ${DATA}/${TEST}.spm \
--destdir ${DEST}/${NAME} \
--thresholdtgt 0 \
--thresholdsrc 0 \
--srcdict ${DICT} \
--tgtdict ${DICT} \
--workers 70
```
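
Here `dict.txt` is the shared dictionary distributed with the pre-trained model, and `${DEST}/${NAME}` is the output directory for the binarized data. A minimal sketch of the assumed setup (names are illustrative):

```bash
# Illustrative setup for the preprocessing step above; adjust paths to your layout.
cp /path/to/mbart.CC25/dict.txt dict.txt   # shared dictionary shipped with the pre-trained model
DEST=/path/to/binarized
NAME=wmt16_en_ro
mkdir -p ${DEST}/${NAME}
```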

## Finetune on EN-RO
Fine-tune the pre-trained mBART CC25 model on the EN-RO parallel data:

```bash
PRETRAIN=/path/to/model/mbart.cc25
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

python train.py path_2_data \
  --encoder-normalize-before --decoder-normalize-before \
  --arch mbart_large --task translation_from_pretrained_bart \
  --source-lang en_XX --target-lang ro_RO \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --dataset-impl mmap \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 \
  --warmup-updates 2500 --total-num-update 40000 \
  --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
  --max-tokens 1024 --update-freq 2 \
  --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
  --seed 222 --log-format simple --log-interval 2 \
  --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
  --restore-file $PRETRAIN \
  --langs $langs \
  --layernorm-embedding --ddp-backend no_c10d
```
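
The positional `path_2_data` argument is the binarized directory produced by the preprocessing step. The sketch below shows one way to wire it up; the concrete paths are placeholders, not part of the release:

```bash
# Illustrative wiring for the fine-tuning command above.
path_2_data=${DEST}/${NAME}              # binarized data produced by the preprocessing step
PRETRAIN=/path/to/mbart.CC25/model.pt    # assumption: --restore-file expects the checkpoint file inside the extracted archive

# If a GPU runs out of memory, --max-tokens can be halved and --update-freq doubled
# to keep the effective batch size unchanged.
```
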
## Generate on EN-RO
Compute sacreBLEU with the fine-tuned EN-RO model.

Set up a tokenizer from [wmt16-scripts](https://github.com/rsennrich/wmt16-scripts) and download the fine-tuned checkpoint:

```bash
wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz
tar -xzvf mbart.cc25.ft.enro.tar.gz
```

```bash
model=model.pt
python generate.py path_2_data \
  --path $model \
  --task translation_from_pretrained_bart \
  --gen-subset test \
  -t ro_RO -s en_XX \
  --bpe 'sentencepiece' --sentencepiece-vocab sentence.bpe.model \
  --sacrebleu --remove-bpe 'sentencepiece' \
  --max-sentences 32 \
  --langs $langs > en_ro

cat en_ro | grep -P "^H" |sort -V |cut -f 3- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.hyp
cat en_ro | grep -P "^T" |sort -V |cut -f 2- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.ref
sacrebleu -tok 'none' -s 'none' en_ro.ref < en_ro.hyp
```
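
`$TOKENIZER` above is assumed to be a command that tokenizes stdin for the language code passed as its argument; the instructions point at wmt16-scripts but do not pin down a specific script, so the following is just one plausible setup using the Moses tokenizer (paths and script choice are assumptions):

```bash
# Assumption: a Moses-style tokenizer invoked as `$TOKENIZER <lang>`, reading stdin.
MOSES=/path/to/mosesdecoder
TOKENIZER="perl ${MOSES}/scripts/tokenizer/tokenizer.perl -no-escape -l"
# `$TOKENIZER ro` then expands to `perl tokenizer.perl -no-escape -l ro`.
```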

## Citation

```bibtex
@article{liu2020multilingual,
    title={Multilingual Denoising Pre-training for Neural Machine Translation},
    author={Yinhan Liu and Jiatao Gu and Naman Goyal and Xian Li and Sergey Edunov and Marjan Ghazvininejad and Mike Lewis and Luke Zettlemoyer},
    year={2020},
    eprint={2001.08210},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
8 changes: 5 additions & 3 deletions fairseq/data/denoising_dataset.py
@@ -99,7 +99,8 @@ def __init__(
mask_whole_words,
shuffle,
seed,
args
args,
eos=None
):
self.dataset = dataset

@@ -115,6 +116,7 @@ def __init__(
self.insert_ratio = args.insert
self.rotate_ratio = args.rotate
self.permute_sentence_ratio = args.permute_sentences
self.eos = (eos if eos is not None else vocab.eos())

if args.bpe != 'gpt2':
self.full_stop_index = self.vocab.index(".")
@@ -155,7 +157,7 @@ def set_epoch(self, epoch, **unused):
def __getitem__(self, index):
with data_utils.numpy_seed(self.seed, self.epoch, index):
tokens = self.dataset[index]
assert tokens[-1] == self.vocab.eos()
assert tokens[-1] == self.eos
source, target = tokens, tokens.clone()

if self.permute_sentence_ratio > 0.0:
@@ -174,7 +176,7 @@ def __getitem__(self, index):
assert (source[1:-1] >= 1).all()
assert (source <= len(self.vocab)).all()
assert source[0] == self.vocab.bos()
assert source[-1] == self.vocab.eos()
assert source[-1] == self.eos
return {
'id': index,
'source': source,
5 changes: 3 additions & 2 deletions fairseq/data/language_pair_dataset.py
@@ -159,7 +159,7 @@ def __init__(
shuffle=True, input_feeding=True,
remove_eos_from_source=False, append_eos_to_target=False,
align_dataset=None,
append_bos=False
append_bos=False, eos=None
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
@@ -183,6 +183,7 @@
if self.align_dataset is not None:
assert self.tgt_sizes is not None, "Both source and target needed when alignments are provided"
self.append_bos = append_bos
self.eos = (eos if eos is not None else src_dict.eos())

def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
@@ -252,7 +253,7 @@ def collater(self, samples):
on the left if *left_pad_target* is ``True``.
"""
return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
samples, pad_idx=self.src_dict.pad(), eos_idx=self.eos,
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
input_feeding=self.input_feeding,
)
6 changes: 6 additions & 0 deletions fairseq/models/bart/model.py
@@ -249,3 +249,9 @@ def bart_large_architecture(args):
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)


@register_model_architecture('bart', 'mbart_large')
def mbart_large_architecture(args):
args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
bart_large_architecture(args)
3 changes: 2 additions & 1 deletion fairseq/sequence_generator.py
@@ -28,6 +28,7 @@ def __init__(
match_source_len=False,
no_repeat_ngram_size=0,
search_strategy=None,
eos=None
):
"""Generates translations of a given source sentence.
@@ -54,7 +55,7 @@
"""
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.eos = tgt_dict.eos() if eos is None else eos
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
4 changes: 2 additions & 2 deletions fairseq/sequence_scorer.py
@@ -12,9 +12,9 @@
class SequenceScorer(object):
"""Scores the target for a given source sentence."""

def __init__(self, tgt_dict, softmax_batch=None, compute_alignment=False):
def __init__(self, tgt_dict, softmax_batch=None, compute_alignment=False, eos=None):
self.pad = tgt_dict.pad()
self.eos = tgt_dict.eos()
self.eos = tgt_dict.eos() if eos is None else eos
self.softmax_batch = softmax_batch or sys.maxsize
assert self.softmax_batch > 0
self.compute_alignment = compute_alignment