0.6.1 -> 0.6.2 #577

Closed · wants to merge 4 commits into from

7 changes: 4 additions & 3 deletions README.md
@@ -5,7 +5,7 @@ developers to train custom models for translation, summarization, language
modeling and other text generation tasks. It provides reference implementations
of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
@@ -18,7 +18,8 @@ of various sequence-to-sequence models, including:
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
- **_New_** [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- **_New_** [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
- **_New_** [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)

Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
@@ -88,7 +89,7 @@ We also have more detailed READMEs to reproduce results from specific papers:
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)

# Join the fairseq community

4 changes: 2 additions & 2 deletions docs/conf.py
@@ -60,9 +60,9 @@
# built documents.
#
# The short X.Y version.
version = '0.6.1'
version = '0.6.2'
# The full version, including alpha/beta/rc tags.
release = '0.6.1'
release = '0.6.2'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
47 changes: 34 additions & 13 deletions eval_lm.py
@@ -14,6 +14,7 @@
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module
@@ -65,11 +66,22 @@ def main(parsed_args):
for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
setattr(args, arg, getattr(parsed_args, arg))

# reduce tokens per sample by the required context window size
args.tokens_per_sample -= args.context_window
task = tasks.setup_task(args)

# Load dataset splits
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
dataset = task.dataset(args.gen_subset)
if args.context_window > 0:
dataset = LMContextWindowDataset(
dataset=dataset,
tokens_per_sample=args.tokens_per_sample,
context_window=args.context_window,
pad_idx=task.source_dictionary.pad(),
)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
@@ -84,7 +96,7 @@ def main(parsed_args):
print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
dataset=dataset,
max_tokens=args.max_tokens or 36000,
max_sentences=args.max_sentences,
max_positions=utils.resolve_max_positions(*[
@@ -97,7 +109,7 @@
).next_epoch_itr(shuffle=False)

gen_timer = StopwatchMeter()
scorer = SequenceScorer(task.target_dictionary)
scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

score_sum = 0.
count = 0
@@ -107,7 +119,11 @@
raise NotImplementedError
else:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
bpe_toks = set(
i
for i in range(len(task.source_dictionary))
if task.source_dictionary[i].endswith(bpe_cont)
)
bpe_len = len(bpe_cont)
else:
bpe_toks = None
@@ -117,31 +133,36 @@

with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()

for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if 'net_input' not in sample:
continue

sample = utils.move_to_cuda(sample) if use_cuda else sample

gen_timer.start()
hypos = scorer.generate(models, sample)
gen_timer.stop(sample['ntokens'])

for hypos_i in hypos:
hypo = hypos_i[0]
pos_scores = hypo['positional_scores']

tokens = hypo['tokens']
tgt_len = tokens.numel()
pos_scores = hypo['positional_scores'].float()

skipped_toks = 0
if bpe_toks is not None:
for i in range(len(hypo['tokens']) - 1):
if hypo['tokens'][i].item() in bpe_toks:
for i in range(tgt_len - 1):
if tokens[i].item() in bpe_toks:
skipped_toks += 1
pos_scores[i + 1] += pos_scores[i]
pos_scores[i] = 0

inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
task.target_dictionary.string(tokens[inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum().cpu()
count += pos_scores.numel() - skipped_toks
@@ -150,9 +171,9 @@ def main(parsed_args):
w = ''
word_prob = []
is_bpe = False
for i in range(len(hypo['tokens'])):
w_ind = hypo['tokens'][i].item()
w += task.dictionary[w_ind]
for i in range(len(tokens)):
w_ind = tokens[i].item()
w += task.source_dictionary[w_ind]
if bpe_toks is not None and w_ind in bpe_toks:
w = w[:-bpe_len]
is_bpe = True
@@ -161,7 +182,7 @@

next_prob = None
ind = i + 1
while ind < len(hypo['tokens']):
while ind < len(tokens):
if pos_scores[ind].item() != 0:
next_prob = pos_scores[ind]
break
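The `--context-window` logic above is the main behavioural change in `eval_lm.py`: `tokens_per_sample` is reduced by the window size, and `LMContextWindowDataset` then prepends tokens from the preceding text to each block, so every scored position can condition on up to `context_window` extra tokens of left context without any token being scored twice. A minimal sketch of the sliding-window arithmetic (the helper below is illustrative only and does not mirror fairseq's actual `LMContextWindowDataset`, which also handles batching and padding):

```python
# Illustrative only: fairseq's LMContextWindowDataset handles batching,
# padding and ordering; this just shows the sliding-window arithmetic.
def context_window_blocks(tokens, tokens_per_sample, context_window):
    """Yield (block, n_context) pairs where each block holds up to
    `tokens_per_sample` tokens, of which the first `n_context` are
    previously seen context and only the remainder gets scored."""
    assert 0 <= context_window < tokens_per_sample
    step = tokens_per_sample - context_window  # new tokens per block
    for start in range(0, len(tokens), step):
        ctx_start = max(0, start - context_window)
        yield tokens[ctx_start:start + step], start - ctx_start

# Example: 20 tokens, blocks of 8 with a context window of 4 score
# 4 fresh tokens each while re-using 4 tokens of left context.
for block, n_ctx in context_window_blocks(list(range(20)), 8, 4):
    print(block, "-> scored:", block[n_ctx:])
```

This is also why `eval_lm.py` subtracts `args.context_window` from `args.tokens_per_sample` before building the task: the wrapper re-fills each block back up to the model's maximum length with context tokens.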
1 change: 0 additions & 1 deletion examples/.gitignore
@@ -1,3 +1,2 @@
*/*
!*/*.sh
!*/*.md
26 changes: 0 additions & 26 deletions examples/conv_lm/README.md

This file was deleted.

40 changes: 32 additions & 8 deletions examples/language_model/README.md
@@ -2,10 +2,10 @@

## Pre-trained models

Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2)
Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)

## Example usage

@@ -16,6 +16,8 @@ These scripts provide an example of pre-processing data for the Language Modeling
Provides an example of pre-processing for [WikiText-103 language modeling task](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):

Example usage:

Prepare data:
```
$ cd examples/language_model/
$ bash prepare-wikitext-103.sh
@@ -27,17 +29,39 @@ $ TEXT=examples/language_model/wikitext-103
$ fairseq-preprocess --only-source \
--trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
--destdir data-bin/wikitext-103
```

Train a transformer language model with adaptive inputs ([Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](transformer_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
$ mkdir -p checkpoints/transformer_wikitext-103
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d

# Train the model:
# If it runs out of memory, try to reduce max-tokens and max-target-positions
$ mkdir -p checkpoints/wikitext-103
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/transformer_wikitext-103/checkpoint_best.pt' \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024

```


Train a convolutional language model ([Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](conv_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
$ mkdir -p checkpoints/fconv_wikitext-103
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/fconv_wikitext-103 \
--max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \
--ddp-backend=no_c10d

# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/wiki103/checkpoint_best.pt'
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/fconv_wikitext-103/checkpoint_best.pt'

```
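The evaluation commands above also pass `--softmax-batch`, which `eval_lm.py` now forwards to `SequenceScorer`. The point is to cap peak memory on large vocabularies by computing output probabilities over at most that many positions at a time. A rough sketch of the chunking idea in plain PyTorch (the function and tensor names here are hypothetical stand-ins, not fairseq's internal API):

```python
# Hypothetical sketch of chunked scoring; fairseq's SequenceScorer
# implements this internally and differs in detail.
import torch
import torch.nn.functional as F

def chunked_target_logprobs(hidden, out_proj, targets, softmax_batch=1024):
    """hidden: (n_tokens, dim); out_proj: nn.Linear(dim, vocab_size);
    targets: (n_tokens,). Returns log p(target) per position while only
    materialising softmax_batch x vocab_size logits at a time."""
    scores = []
    for start in range(0, hidden.size(0), softmax_batch):
        h = hidden[start:start + softmax_batch]
        t = targets[start:start + softmax_batch]
        lprobs = F.log_softmax(out_proj(h), dim=-1)   # (chunk, vocab)
        scores.append(lprobs.gather(1, t.unsqueeze(1)).squeeze(1))
    return torch.cat(scores)

# Usage with random stand-ins for model outputs (dim=16, vocab=1000).
proj = torch.nn.Linear(16, 1000)
pos_scores = chunked_target_logprobs(torch.randn(3000, 16), proj,
                                      torch.randint(0, 1000, (3000,)))
print(pos_scores.shape)  # torch.Size([3000])
```

In `eval_lm.py`, per-position scores like these are what get accumulated into `score_sum` and `count` to report the average negative log-likelihood.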
19 changes: 19 additions & 0 deletions examples/language_model/conv_lm/README.md
@@ -0,0 +1,19 @@
# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)

## Example usage

See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103
using the `fconv_lm_dauphin_wikitext103` model architecture.

## Citation

```bibtex
@inproceedings{dauphin2017language,
title={Language Modeling with Gated Convolutional Networks},
author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
pages={933--941},
year={2017},
organization={JMLR}
}
```
14 changes: 7 additions & 7 deletions examples/language_model/prepare-wikitext-103.sh
100755 → 100644
@@ -21,13 +21,13 @@ for ((i=0;i<${#URLS[@]};++i)); do
echo "$url not successfully downloaded."
exit -1
fi
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
elif [ ${file: -4} == ".zip" ]; then
unzip $file
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
elif [ ${file: -4} == ".zip" ]; then
unzip $file
fi
fi
done
cd ..
26 changes: 26 additions & 0 deletions examples/language_model/transformer_lm/README.md
@@ -0,0 +1,26 @@
# Adaptive Input Representations for Neural Language Modeling (Baevski and Auli; 2018)

## Pre-trained models

Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)

## Example usage

See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103
using the `transformer_lm_wiki103` model architecture.

## Citation

```bibtex
@inproceedings{
baevski2018adaptive,
title={Adaptive Input Representations for Neural Language Modeling},
author={Alexei Baevski and Michael Auli},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=ByxZX20qFQ},
}
```
2 changes: 1 addition & 1 deletion fairseq/__init__.py
@@ -6,7 +6,7 @@
# can be found in the PATENTS file in the same directory.

__all__ = ['pdb']
__version__ = '0.6.1'
__version__ = '0.6.2'

import fairseq.criterions
import fairseq.models
2 changes: 2 additions & 0 deletions fairseq/data/__init__.py
@@ -11,6 +11,7 @@
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .token_block_dataset import TokenBlockDataset
@@ -35,6 +36,7 @@
'IndexedDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'LMContextWindowDataset',
'MonolingualDataset',
'RoundRobinZipDatasets',
'ShardedIterator',