From e625e1c850b9dd0fad7410178f5bb829911c0e1e Mon Sep 17 00:00:00 2001 From: Alexander Rush Date: Fri, 2 Mar 2018 19:33:24 -0500 Subject: [PATCH 1/2] . --- docs/source/options/preprocess.md | 9 +++++-- docs/source/options/train.md | 40 +++++++++++++++++++++++++++++-- docs/source/options/translate.md | 29 ++++++++++++---------- onmt/io/IO.py | 34 ++++++++++++++++++++++++-- opts.py | 6 +++-- 5 files changed, 97 insertions(+), 21 deletions(-) diff --git a/docs/source/options/preprocess.md b/docs/source/options/preprocess.md index f9db3eb165..f96cd6df2d 100644 --- a/docs/source/options/preprocess.md +++ b/docs/source/options/preprocess.md @@ -26,12 +26,17 @@ Source directory for image or audio files. * **-save_data []** Output file for the prepared data +* **-max_shard_size []** +For text corpus of large volume, it will be divided into shards of this size to +preprocess. If 0, the data will be handled as a whole. The unit is in bytes. +Optimal value should be multiples of 64 bytes. + ### **Vocab**: * **-src_vocab []** -Path to an existing source vocabulary +Path to an existing source vocabulary. Format: one word per line. * **-tgt_vocab []** -Path to an existing target vocabulary +Path to an existing target vocabulary. Format: one word per line. * **-features_vocabs_prefix []** Path prefix to existing features vocabularies diff --git a/docs/source/options/train.md b/docs/source/options/train.md index ca3a4009ec..7e7c2c4ec7 100644 --- a/docs/source/options/train.md +++ b/docs/source/options/train.md @@ -70,6 +70,10 @@ layer Feed the context vector at each time step as additional input (via concatenation with the word embeddings) to the decoder. +* **-bridge []** +Have an additional layer between the last encoder state and the first decoder +state + * **-rnn_type [LSTM]** The gate type to use in the RNNs @@ -92,6 +96,12 @@ Train copy attention layer. * **-copy_attn_force []** When available, train to copy. +* **-reuse_copy_attn []** +Reuse standard attention for copy + +* **-copy_loss_by_seqlength []** +Divide copy loss by length of sequence + * **-coverage_attn []** Train a coverage attention layer. @@ -136,11 +146,25 @@ the decoder side. See README for specific formatting instructions. Fix word embeddings on the encoder side. * **-fix_word_vecs_dec []** -Fix word embeddings on the decoder side. +Fix word embeddings on the encoder side. ### **Optimization- Type**: * **-batch_size [64]** -Maximum batch size +Maximum batch size for training + +* **-batch_type [sents]** +Batch grouping for batch_size. Standard is sents. Tokens will do dynamic +batching + +* **-normalization [sents]** +Normalization method of the gradient. + +* **-accum_count [1]** +Accumulate gradient this many times. Approximately equivalent to updating +batch_size * accum_count batches at once. Recommended for Transformer. + +* **-valid_batch_size [32]** +Maximum batch size for validation * **-max_generator_batches [32]** Maximum batches of words in a sequence to run the generator on in parallel. @@ -181,6 +205,11 @@ https://keras.io/optimizers/ . Whereas recently the paper "Attention is All You Need" suggested a value of 0.98 for beta2, this parameter may not work well for normal models / default baselines. +* **-label_smoothing []** +Label smoothing value epsilon. Probabilities of all non-true labels will be +smoothed by epsilon / (vocab_size - 1). Set to zero to turn off label smoothing. 
+For more detailed information, see: https://arxiv.org/abs/1512.00567 + ### **Optimization- Rate**: * **-learning_rate [1.0]** Starting learning rate. Recommended settings: sgd = 1, adagrad = 0.1, adadelta = @@ -212,6 +241,13 @@ Send logs to this crayon server. * **-exp []** Name of the experiment for logging. +* **-tensorboard []** +Use tensorboardX for visualization during training. Must have the library +tensorboardX. + +* **-tensorboard_log_dir [runs]** +Log directory for Tensorboard. + ### **Speech**: * **-sample_rate [16000]** Sample rate. diff --git a/docs/source/options/translate.md b/docs/source/options/translate.md index c1010a4700..cf23ec1dfa 100644 --- a/docs/source/options/translate.md +++ b/docs/source/options/translate.md @@ -24,6 +24,13 @@ True target sequence (optional) * **-output [pred.txt]** Path to output the predictions (each line will be the decoded sequence +* **-report_bleu []** +Report bleu score after translation, call tools/multi-bleu.perl on command line + +* **-report_rouge []** +Report rouge 1/2/3/L/SU4 score after translation call tools/test_rouge.py on +command line + * **-dynamic_dict []** Create dynamic dictionaries @@ -34,15 +41,21 @@ Share source and target vocabulary * **-beam_size [5]** Beam size +* **-min_length []** +Minimum prediction length + +* **-max_length [100]** +Maximum prediction length. + +* **-max_sent_length []** +Deprecated, use `-max_length` instead + * **-alpha []** Google NMT length penalty parameter (higher = longer generation) * **-beta []** Coverage penalty parameter -* **-max_sent_length [100]** -Maximum sentence length. - * **-replace_unk []** Replace the generated UNK tokens with the source token that had highest attention weight. If phrase_table is provided, it will lookup the identified @@ -82,13 +95,3 @@ Window stride for spectrogram in seconds * **-window [hamming]** Window type for spectrogram generation - -### **Score**: -* **-report_bleu []** -Report bleu score after translation by calling tools/multi-bleu.perl -on command line. - -* **-report_rouge []** -Report Report rouge 1/2/3/L/SU4 score after translation by calling -tools/multi-bleu.perl on command line. Use pyrouge as backend. Scores may be -slightly different with those by calling files2rouge. \ No newline at end of file diff --git a/onmt/io/IO.py b/onmt/io/IO.py index 7b3edb7fda..538be5b8c2 100644 --- a/onmt/io/IO.py +++ b/onmt/io/IO.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +import os from collections import Counter, defaultdict, OrderedDict from itertools import count @@ -226,17 +227,19 @@ def _build_field_vocab(field, counter, **kwargs): def build_vocab(train_dataset_files, fields, data_type, share_vocab, - src_vocab_size, src_words_min_frequency, - tgt_vocab_size, tgt_words_min_frequency): + src_vocab_path, src_vocab_size, src_words_min_frequency, + tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency): """ Args: train_dataset_files: a list of train dataset pt file. fields (dict): fields to build vocab for. data_type: "text", "img" or "audio"? share_vocab(bool): share source and target vocabulary? + src_vocab_path(string): Path to src vocabulary file. src_vocab_size(int): size of the source vocabulary. src_words_min_frequency(int): the minimum frequency needed to include a source word in the vocabulary. + tgt_vocab_path(string): Path to tgt vocabulary file. tgt_vocab_size(int): size of the target vocabulary. tgt_words_min_frequency(int): the minimum frequency needed to include a target word in the vocabulary. 
@@ -248,6 +251,29 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab, for k in fields: counter[k] = Counter() + # Load vocabulary + src_vocab = None + if len(src_vocab_path) > 0: + src_vocab = set([]) + print ('Loading source vocab from %s' % src_vocab_path) + assert os.path.exists(src_vocab_path), \ + 'src vocab %s not found!' % src_vocab_path + with open(src_vocab_path) as f: + for line in f: + word = line.strip().split()[0] + src_vocab.add(word) + + tgt_vocab = None + if len(tgt_vocab_path) > 0: + tgt_vocab = set([]) + print ('Loading target vocab from %s' % tgt_vocab_path) + assert os.path.exists(tgt_vocab_path), \ + 'tgt vocab %s not found!' % tgt_vocab_path + with open(tgt_vocab_path) as f: + for line in f: + word = line.strip().split()[0] + tgt_vocab.add(word) + for path in train_dataset_files: dataset = torch.load(path) print(" * reloading %s." % path) @@ -256,6 +282,10 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab, val = getattr(ex, k, None) if val is not None and not fields[k].sequential: val = [val] + elif k == 'src' and src_vocab: + val = [item for item in val if item in src_vocab] + elif k == 'tgt' and tgt_vocab: + val = [item for item in val if item in tgt_vocab] counter[k].update(val) _build_field_vocab(fields["tgt"], counter["tgt"], diff --git a/opts.py b/opts.py index 57e5a1dc30..1d557f04b3 100644 --- a/opts.py +++ b/opts.py @@ -153,9 +153,11 @@ def preprocess_opts(parser): group = parser.add_argument_group('Vocab') group.add_argument('-src_vocab', - help="Path to an existing source vocabulary") + help="""Path to an existing source vocabulary. Format: + one word per line.""") group.add_argument('-tgt_vocab', - help="Path to an existing target vocabulary") + help="""Path to an existing target vocabulary. Format: + one word per line.""") group.add_argument('-features_vocabs_prefix', type=str, default='', help="Path prefix to existing features vocabularies") group.add_argument('-src_vocab_size', type=int, default=50000, From 97760f26a31e6da06397edfa85102dd0136e6a5e Mon Sep 17 00:00:00 2001 From: Alexander Rush Date: Fri, 2 Mar 2018 20:24:55 -0500 Subject: [PATCH 2/2] . --- onmt/io/IO.py | 4 ++-- opts.py | 4 ++-- preprocess.py | 2 ++ test/test_preprocess.py | 14 ++++++++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/onmt/io/IO.py b/onmt/io/IO.py index 538be5b8c2..f61df17663 100644 --- a/onmt/io/IO.py +++ b/onmt/io/IO.py @@ -255,7 +255,7 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab, src_vocab = None if len(src_vocab_path) > 0: src_vocab = set([]) - print ('Loading source vocab from %s' % src_vocab_path) + print('Loading source vocab from %s' % src_vocab_path) assert os.path.exists(src_vocab_path), \ 'src vocab %s not found!' % src_vocab_path with open(src_vocab_path) as f: @@ -266,7 +266,7 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab, tgt_vocab = None if len(tgt_vocab_path) > 0: tgt_vocab = set([]) - print ('Loading target vocab from %s' % tgt_vocab_path) + print('Loading target vocab from %s' % tgt_vocab_path) assert os.path.exists(tgt_vocab_path), \ 'tgt vocab %s not found!' 
% tgt_vocab_path with open(tgt_vocab_path) as f: diff --git a/opts.py b/opts.py index 1d557f04b3..b7c9402dee 100644 --- a/opts.py +++ b/opts.py @@ -152,10 +152,10 @@ def preprocess_opts(parser): # Dictionary options, for text corpus group = parser.add_argument_group('Vocab') - group.add_argument('-src_vocab', + group.add_argument('-src_vocab', default="", help="""Path to an existing source vocabulary. Format: one word per line.""") - group.add_argument('-tgt_vocab', + group.add_argument('-tgt_vocab', default="", help="""Path to an existing target vocabulary. Format: one word per line.""") group.add_argument('-features_vocabs_prefix', type=str, default='', diff --git a/preprocess.py b/preprocess.py index e226cdcf95..70a4894328 100755 --- a/preprocess.py +++ b/preprocess.py @@ -157,8 +157,10 @@ def build_save_dataset(corpus_type, fields, opt): def build_save_vocab(train_dataset, fields, opt): fields = onmt.io.build_vocab(train_dataset, fields, opt.data_type, opt.share_vocab, + opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency, + opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency) diff --git a/test/test_preprocess.py b/test/test_preprocess.py index 9cf7500e3e..9e92ac0a5c 100644 --- a/test/test_preprocess.py +++ b/test/test_preprocess.py @@ -3,6 +3,7 @@ import unittest import glob import os +import codecs from collections import Counter import torchtext @@ -39,6 +40,13 @@ def __init__(self, *args, **kwargs): def dataset_build(self, opt): fields = onmt.io.get_fields("text", 0, 0) + if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0: + with codecs.open(opt.src_vocab, 'w', 'utf-8') as f: + f.write('a\nb\nc\nd\ne\nf\n') + if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0: + with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f: + f.write('a\nb\nc\nd\ne\nf\n') + train_data_files = preprocess.build_save_dataset('train', fields, opt) preprocess.build_save_vocab(train_data_files, fields, opt) @@ -48,6 +56,10 @@ def dataset_build(self, opt): # Remove the generated *pt files. for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): os.remove(pt) + if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab): + os.remove(opt.src_vocab) + if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab): + os.remove(opt.tgt_vocab) def test_merge_vocab(self): va = torchtext.vocab.Vocab(Counter('abbccc')) @@ -109,6 +121,8 @@ def test_method(self): ('share_vocab', True)], [('dynamic_dict', True), ('max_shard_size', 500000)], + [('src_vocab', '/tmp/src_vocab.txt'), + ('tgt_vocab', '/tmp/tgt_vocab.txt')], ] for p in test_databuild:
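
Note on the new `-src_vocab` / `-tgt_vocab` preprocessing options (a usage sketch, not part of the diffs above): the vocabulary files are expected to contain one word per line, and only the first whitespace-separated token of each line is read. The helper below restates the loading logic this patch adds inline to `onmt.io.IO.build_vocab` as a standalone function for illustration; the function name `load_fixed_vocab`, the `tag` argument, and the blank-line guard are illustrative additions, not code from the patch.

```python
# Standalone sketch of the vocabulary-file loading introduced in this patch.
# The patch inlines this logic in onmt.io.IO.build_vocab; the function name,
# the `tag` argument, and the blank-line guard are illustrative only.
import os


def load_fixed_vocab(vocab_path, tag):
    """Return the set of words read from `vocab_path` (one word per line),
    or None if no path was given (matching the '' default in opts.py)."""
    if not vocab_path:
        return None
    print('Loading %s vocab from %s' % (tag, vocab_path))
    assert os.path.exists(vocab_path), \
        '%s vocab %s not found!' % (tag, vocab_path)
    vocab = set()
    with open(vocab_path) as f:
        for line in f:
            if not line.strip():
                continue  # skip blank lines (guard not present in the patch)
            # Keep only the first whitespace-separated token, as the patch does.
            vocab.add(line.strip().split()[0])
    return vocab
```

With the patch applied, `preprocess.py` forwards `opt.src_vocab` and `opt.tgt_vocab` into `build_vocab`, which drops any `src`/`tgt` token not found in the loaded set before updating the frequency counters, so the final field vocabularies are built only over in-vocabulary words (still subject to the `-src_vocab_size` / `-tgt_vocab_size` and minimum-frequency cut-offs).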