
Commit 0049349

Myle Ott authored and facebook-github-bot committed

Multilingual training example (facebookresearch#527)

Summary:
* Add example for multilingual translation on IWSLT'17
* Match dataset ordering for multilingual_translation and translation
* Fix bug with LegacyDistributedDataParallel when calling forward of sub-modules

Pull Request resolved: facebookresearch#527
Differential Revision: D14218372
Pulled By: myleott
fbshipit-source-id: 2e3fe24aa39476bcc5c9af68ef9a40192db34a3b

1 parent 44d27e6 commit 0049349

10 files changed, +388 −23 lines

examples/translation/README.md (+68 −1)

@@ -92,7 +92,6 @@ $ fairseq-generate data-bin/iwslt14.tokenized.de-en \
 ```
 
-
 ### prepare-wmt14en2de.sh
 
 The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.

@@ -163,3 +162,71 @@ $ fairseq-generate data-bin/fconv_wmt_en_fr \
    --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```

## Multilingual Translation

We also support training multilingual translation models. In this example we'll
train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.

Note that we use slightly different preprocessing here than for the IWSLT'14
En-De data above. In particular we learn a joint BPE code for all three
languages and use fairseq-interactive and sacrebleu for scoring the test set.

```
# First install sacrebleu and sentencepiece
$ pip install sacrebleu sentencepiece

# Then download and preprocess the data
$ cd examples/translation/
$ bash prepare-iwslt17-multilingual.sh
$ cd ../..

# Binarize the de-en dataset
$ TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
$ fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train.bpe.de-en --validpref $TEXT/valid.bpe.de-en \
    --joined-dictionary \
    --destdir data-bin/iwslt17.de_fr.en.bpe16k \
    --workers 10

# Binarize the fr-en dataset
# NOTE: it's important to reuse the en dictionary from the previous step
$ fairseq-preprocess --source-lang fr --target-lang en \
    --trainpref $TEXT/train.bpe.fr-en --validpref $TEXT/valid.bpe.fr-en \
    --joined-dictionary --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
    --destdir data-bin/iwslt17.de_fr.en.bpe16k \
    --workers 10

# Train a multilingual transformer model
# NOTE: the command below assumes 1 GPU, but accumulates gradients from
# 8 fwd/bwd passes to simulate training on 8 GPUs
$ mkdir -p checkpoints/multilingual_transformer
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
    --max-epoch 50 \
    --ddp-backend=no_c10d \
    --task multilingual_translation --lang-pairs de-en,fr-en \
    --arch multilingual_transformer_iwslt_de_en \
    --share-decoders --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
    --warmup-updates 4000 --warmup-init-lr '1e-07' \
    --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
    --dropout 0.3 --weight-decay 0.0001 \
    --save-dir checkpoints/multilingual_transformer \
    --max-tokens 4000 \
    --update-freq 8

# Generate and score the test set with sacrebleu
$ SRC=de
$ sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
    | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
    > iwslt17.test.${SRC}-en.${SRC}.bpe
$ cat iwslt17.test.${SRC}-en.${SRC}.bpe | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
    --task multilingual_translation --source-lang ${SRC} --target-lang en \
    --path checkpoints/multilingual_transformer/checkpoint_best.pt \
    --buffer-size 2000 --batch-size 128 \
    --beam 5 --remove-bpe=sentencepiece \
    > iwslt17.test.${SRC}-en.en.sys
$ grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
    | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
```
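A quick sanity check on the gradient-accumulation note above: with `--max-tokens 4000` and `--update-freq 8`, each optimizer step aggregates roughly 8 × 4,000 = 32,000 tokens, the same effective batch size as 8 GPUs each processing up to 4,000 tokens per forward/backward pass. To score the French direction, rerun the generation block with `SRC=fr`.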
examples/translation/prepare-iwslt17-multilingual.sh (new file, +126)

@@ -0,0 +1,126 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

SRCS=(
    "de"
    "fr"
)
TGT=en

ROOT=$(dirname "$0")
SCRIPTS=$ROOT/../../scripts
SPM_TRAIN=$SCRIPTS/spm_train.py
SPM_ENCODE=$SCRIPTS/spm_encode.py

BPESIZE=16384
ORIG=$ROOT/iwslt17_orig
DATA=$ROOT/iwslt17.de_fr.en.bpe16k
mkdir -p "$ORIG" "$DATA"

TRAIN_MINLEN=1    # remove sentences with <1 BPE token
TRAIN_MAXLEN=250  # remove sentences with >250 BPE tokens

URLS=(
    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz"
    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz"
)
ARCHIVES=(
    "de-en.tgz"
    "fr-en.tgz"
)
VALID_SETS=(
    "IWSLT17.TED.dev2010.de-en IWSLT17.TED.tst2010.de-en IWSLT17.TED.tst2011.de-en IWSLT17.TED.tst2012.de-en IWSLT17.TED.tst2013.de-en IWSLT17.TED.tst2014.de-en IWSLT17.TED.tst2015.de-en"
    "IWSLT17.TED.dev2010.fr-en IWSLT17.TED.tst2010.fr-en IWSLT17.TED.tst2011.fr-en IWSLT17.TED.tst2012.fr-en IWSLT17.TED.tst2013.fr-en IWSLT17.TED.tst2014.fr-en IWSLT17.TED.tst2015.fr-en"
)

# download and extract data
for ((i=0;i<${#URLS[@]};++i)); do
    ARCHIVE=$ORIG/${ARCHIVES[i]}
    if [ -f "$ARCHIVE" ]; then
        echo "$ARCHIVE already exists, skipping download"
    else
        URL=${URLS[i]}
        wget -P "$ORIG" "$URL"
        if [ -f "$ARCHIVE" ]; then
            echo "$URL successfully downloaded."
        else
            echo "$URL not successfully downloaded."
            exit 1
        fi
    fi
    FILE=${ARCHIVE: -4}
    if [ -e "$FILE" ]; then
        echo "$FILE already exists, skipping extraction"
    else
        tar -C "$ORIG" -xzvf "$ARCHIVE"
    fi
done

echo "pre-processing train data..."
for SRC in "${SRCS[@]}"; do
    for LANG in "${SRC}" "${TGT}"; do
        cat "$ORIG/${SRC}-${TGT}/train.tags.${SRC}-${TGT}.${LANG}" \
            | grep -v '<url>' \
            | grep -v '<talkid>' \
            | grep -v '<keywords>' \
            | grep -v '<speaker>' \
            | grep -v '<reviewer' \
            | grep -v '<translator' \
            | grep -v '<doc' \
            | grep -v '</doc>' \
            | sed -e 's/<title>//g' \
            | sed -e 's/<\/title>//g' \
            | sed -e 's/<description>//g' \
            | sed -e 's/<\/description>//g' \
            | sed 's/^\s*//g' \
            | sed 's/\s*$//g' \
            > "$DATA/train.${SRC}-${TGT}.${LANG}"
    done
done

echo "pre-processing valid data..."
for ((i=0;i<${#SRCS[@]};++i)); do
    SRC=${SRCS[i]}
    VALID_SET=${VALID_SETS[i]}
    for FILE in ${VALID_SET[@]}; do
        for LANG in "$SRC" "$TGT"; do
            grep '<seg id' "$ORIG/${SRC}-${TGT}/${FILE}.${LANG}.xml" \
                | sed -e 's/<seg id="[0-9]*">\s*//g' \
                | sed -e 's/\s*<\/seg>\s*//g' \
                | sed -e "s/\’/\'/g" \
                > "$DATA/valid.${SRC}-${TGT}.${LANG}"
        done
    done
done

# learn BPE with sentencepiece
TRAIN_FILES=$(for SRC in "${SRCS[@]}"; do echo $DATA/train.${SRC}-${TGT}.${SRC}; echo $DATA/train.${SRC}-${TGT}.${TGT}; done | tr "\n" ",")
echo "learning joint BPE over ${TRAIN_FILES}..."
python "$SPM_TRAIN" \
    --input=$TRAIN_FILES \
    --model_prefix=$DATA/sentencepiece.bpe \
    --vocab_size=$BPESIZE \
    --character_coverage=1.0 \
    --model_type=bpe

# encode train/valid with the learned BPE
echo "encoding train/valid with learned BPE..."
for SRC in "${SRCS[@]}"; do
    for LANG in "$SRC" "$TGT"; do
        python "$SPM_ENCODE" \
            --model "$DATA/sentencepiece.bpe.model" \
            --output_format=piece \
            --inputs $DATA/train.${SRC}-${TGT}.${SRC} $DATA/train.${SRC}-${TGT}.${TGT} \
            --outputs $DATA/train.bpe.${SRC}-${TGT}.${SRC} $DATA/train.bpe.${SRC}-${TGT}.${TGT} \
            --min-len $TRAIN_MINLEN --max-len $TRAIN_MAXLEN
        python "$SPM_ENCODE" \
            --model "$DATA/sentencepiece.bpe.model" \
            --output_format=piece \
            --inputs $DATA/valid.${SRC}-${TGT}.${SRC} $DATA/valid.${SRC}-${TGT}.${TGT} \
            --outputs $DATA/valid.bpe.${SRC}-${TGT}.${SRC} $DATA/valid.bpe.${SRC}-${TGT}.${TGT}
    done
done
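If the script runs cleanly, `examples/translation/iwslt17.de_fr.en.bpe16k/` should end up containing the joint sentencepiece model (`sentencepiece.bpe.model` plus its vocabulary file) along with the raw and BPE-encoded bitext (`train.{de,fr}-en.*`, `train.bpe.{de,fr}-en.*`, `valid.{de,fr}-en.*`, `valid.bpe.{de,fr}-en.*`), which is exactly what the `fairseq-preprocess` commands in the README above consume via `--trainpref`/`--validpref`.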

fairseq/data/round_robin_zip_datasets.py (+11 −4)

@@ -39,12 +39,11 @@ def __init__(self, datasets, eval_key=None):
                 self.longest_dataset = dataset
                 self.longest_dataset_key = key
 
-        self._ordered_indices = OrderedDict([
-            (key, dataset.ordered_indices())
-            for key, dataset in datasets.items()
-        ])
+        self._ordered_indices = None
 
     def _map_index(self, key, index):
+        assert self._ordered_indices is not None, \
+            'Must call RoundRobinZipDatasets.ordered_indices() first'
         return self._ordered_indices[key][index % len(self.datasets[key])]
 
     def __getitem__(self, index):

@@ -102,6 +101,14 @@ def size(self, index):
 
     def ordered_indices(self):
         """Ordered indices for batching."""
+        if self._ordered_indices is None:
+            # Call the underlying dataset's ordered_indices() here, so that we
+            # get the same random ordering as we would have from using the
+            # underlying dataset directly.
+            self._ordered_indices = OrderedDict([
+                (key, dataset.ordered_indices())
+                for key, dataset in self.datasets.items()
+            ])
         return np.arange(len(self))
 
     @property
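The lazy initialization above creates an ordering contract: `ordered_indices()` must be called before any index is mapped, which is what the new assertion enforces. A minimal standalone sketch of the same pattern (a toy class, not fairseq's actual `RoundRobinZipDatasets`):

```
from collections import OrderedDict

import numpy as np


class LazyZip:
    """Toy illustration of the lazy ordered_indices() pattern from the diff above."""

    def __init__(self, datasets):
        self.datasets = OrderedDict(datasets)
        self._ordered_indices = None  # filled in lazily by ordered_indices()

    def __len__(self):
        return max(len(ds) for ds in self.datasets.values())

    def ordered_indices(self):
        if self._ordered_indices is None:
            # Defer to each sub-dataset's own ordering (a random permutation stands
            # in for dataset.ordered_indices() here), so batching sees the same
            # order it would get from the sub-dataset directly.
            self._ordered_indices = OrderedDict(
                (key, np.random.permutation(len(ds))) for key, ds in self.datasets.items()
            )
        return np.arange(len(self))

    def _map_index(self, key, index):
        assert self._ordered_indices is not None, \
            'Must call ordered_indices() first'
        return self._ordered_indices[key][index % len(self.datasets[key])]


zipped = LazyZip({'de-en': list(range(5)), 'fr-en': list(range(3))})
zipped.ordered_indices()              # must happen before any indexing
print(zipped._map_index('fr-en', 4))  # index 4 wraps around the 3-item dataset
```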

fairseq/legacy_distributed_data_parallel.py (+1 −1)

@@ -75,7 +75,6 @@ def __setstate__(self, state):
         self._register_grad_hook()
 
     def forward(self, *inputs, **kwargs):
-        self.need_reduction = True
         return self.module(*inputs, **kwargs)
 
     def _register_grad_hook(self):

@@ -166,6 +165,7 @@ def reduction_fn():
         for p in self.module.parameters():
 
             def allreduce_hook(*unused):
+                self.need_reduction = True
                 Variable._execution_engine.queue_callback(reduction_fn)
 
             if p.requires_grad:
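The one-line move above matters when callers invoke sub-modules' `forward()` directly rather than going through the wrapper (as the commit summary notes, and as a task holding one sub-model per language pair would do): setting `need_reduction` in the wrapper's own `forward()` misses those calls, while setting it from the gradient hook catches every backward pass. A simplified sketch of the idea, using a plain per-parameter gradient hook rather than fairseq's grad-accumulator hooks:

```
import torch
import torch.nn as nn


class DirtyFlagWrapper(nn.Module):
    """Toy stand-in for LegacyDistributedDataParallel's need_reduction flag.

    The flag is raised from a gradient hook, so it is set whenever gradients are
    produced, even if callers invoke self.module (or its sub-modules) directly.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.need_reduction = False
        for p in self.module.parameters():
            if p.requires_grad:
                p.register_hook(self._mark_dirty)  # fires when p's gradient is computed

    def _mark_dirty(self, grad):
        self.need_reduction = True
        return grad

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)


wrapper = DirtyFlagWrapper(nn.Linear(4, 2))
# Bypass wrapper.forward() entirely, the way a per-language-pair sub-model call would.
loss = wrapper.module(torch.randn(3, 4)).sum()
loss.backward()
print(wrapper.need_reduction)  # True even though wrapper.forward() was never called
```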

fairseq/tasks/fairseq_task.py (−1)

@@ -226,7 +226,6 @@ def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
             - logging outputs to display while training
         """
         model.train()
-
         loss, sample_size, logging_output = criterion(model, sample)
         if ignore_grad:
             loss *= 0

fairseq/trainer.py (+8 −3)

@@ -50,6 +50,7 @@ def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
         self._num_updates = 0
         self._optim_history = None
         self._optimizer = None
+        self._prev_grad_norm = None
         self._wrapped_model = None
 
         self.init_meters(args)

@@ -215,12 +216,15 @@ def train_step(self, samples, dummy_batch=False):
 
         # gather logging outputs from all replicas
         if self.args.distributed_world_size > 1:
-            logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
-                [logging_outputs, sample_sizes, ooms],
-            ))
+            logging_outputs, sample_sizes, ooms, prev_norms = \
+                zip(*distributed_utils.all_gather_list(
+                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
+                ))
             logging_outputs = list(chain.from_iterable(logging_outputs))
             sample_sizes = list(chain.from_iterable(sample_sizes))
             ooms = sum(ooms)
+            assert all(norm == prev_norms[0] for norm in prev_norms), \
+                'Fatal error: gradients are inconsistent between workers'
 
         self.meters['oom'].update(ooms, len(samples))
         if ooms == self.args.distributed_world_size * len(samples):

@@ -246,6 +250,7 @@ def train_step(self, samples, dummy_batch=False):
 
             # clip grads
             grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
+            self._prev_grad_norm = grad_norm
 
             # take an optimization step
             self.optimizer.step()
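The new assertion shares each worker's gradient norm from the previous update through the same all-gather used for logging outputs, and fails fast if the norms diverge, which would indicate the replicas have fallen out of sync. A rough standalone equivalent using plain `torch.distributed` (the real code goes through fairseq's `distributed_utils.all_gather_list`, and `all_gather_object` assumes a reasonably recent PyTorch):

```
import torch.distributed as dist


def assert_grad_norms_consistent(local_norm):
    """Sketch: compare this worker's previous gradient norm against all replicas."""
    if not (dist.is_available() and dist.is_initialized()):
        return  # single-process run, nothing to compare against
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_norm)
    assert all(norm == gathered[0] for norm in gathered), \
        'Fatal error: gradients are inconsistent between workers'
```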

preprocess.py (+13 −13)

@@ -56,37 +56,37 @@ def build_dictionary(filenames, src=False, tgt=False):
             padding_factor=args.padding_factor,
         )
 
+    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
+        raise FileExistsError(dict_path(args.source_lang))
+    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
+        raise FileExistsError(dict_path(args.target_lang))
+
     if args.joined_dictionary:
-        assert (
-            not args.srcdict or not args.tgtdict
-        ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"
+        assert not args.srcdict or not args.tgtdict, \
+            "cannot use both --srcdict and --tgtdict with --joined-dictionary"
 
         if args.srcdict:
             src_dict = task.load_dictionary(args.srcdict)
         elif args.tgtdict:
             src_dict = task.load_dictionary(args.tgtdict)
         else:
-            assert (
-                args.trainpref
-            ), "--trainpref must be set if --srcdict is not specified"
-            src_dict = build_dictionary({train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True)
+            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
+            src_dict = build_dictionary(
+                {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
+            )
         tgt_dict = src_dict
     else:
         if args.srcdict:
             src_dict = task.load_dictionary(args.srcdict)
         else:
-            assert (
-                args.trainpref
-            ), "--trainpref must be set if --srcdict is not specified"
+            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
             src_dict = build_dictionary([train_path(args.source_lang)], src=True)
 
         if target:
             if args.tgtdict:
                 tgt_dict = task.load_dictionary(args.tgtdict)
             else:
-                assert (
-                    args.trainpref
-                ), "--trainpref must be set if --tgtdict is not specified"
+                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                 tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
         else:
             tgt_dict = None
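In the multilingual recipe above, this guard is what makes the dictionary reuse explicit: the second `fairseq-preprocess` run writes into the same `--destdir`, so it now fails with `FileExistsError` unless the existing dictionary is passed back in, which is exactly what the README's fr-en step does with `--tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt`, rather than silently regenerating or overwriting the shared English dictionary.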

scripts/spm_decode.py (new file, +45)

@@ -0,0 +1,45 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse

import sentencepiece as spm


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True,
                        help="sentencepiece model to use for decoding")
    parser.add_argument("--input", required=True, help="input file to decode")
    parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    if args.input_format == "piece":
        def decode(l):
            return "".join(sp.DecodePieces(l))
    elif args.input_format == "id":
        def decode(l):
            return "".join(sp.DecodeIds(l))
    else:
        raise NotImplementedError

    def tok2int(tok):
        # remap reference-side <unk> (represented as <<unk>>) to 0
        return int(tok) if tok != "<<unk>>" else 0

    with open(args.input, "r", encoding="utf-8") as h:
        for line in h:
            if args.input_format == "id":
                # ids need to be converted to ints before decoding
                print(decode(list(map(tok2int, line.rstrip().split()))))
            else:
                print(decode(line.rstrip().split()))


if __name__ == "__main__":
    main()
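`spm_decode.py` is the inverse of the `spm_encode.py` call used in the README's generation pipeline. It is not needed there, since `fairseq-interactive --remove-bpe=sentencepiece` already strips the BPE, but it can be handy for inspecting encoded files, for example `python scripts/spm_decode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model --input iwslt17.test.de-en.de.bpe` (paths as produced by the README commands above).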
