Skip to content

Commit

Permalink
Improvements to random sampling at decode time. (#1174)
Browse files Browse the repository at this point in the history
  • Loading branch information
daphnei authored and vince62s committed Jan 14, 2019
1 parent 111c854 commit 02b3976
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ script:
- head /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python preprocess.py -data_type audio -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-val-head.txt -train_tgt /tmp/speech/tgt-val-head.txt -valid_src /tmp/speech/src-val-head.txt -valid_tgt /tmp/speech/tgt-val-head.txt -save_data /tmp/speech/q; python train.py -model_type audio -data /tmp/speech/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/speech/q*.pt
# test nmt translation
- python translate.py -model onmt/tests/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 10 -tgt data/morph/tgt.valid -out /tmp/trans; diff data/morph/tgt.valid /tmp/trans
# test nmt translation with random sampling
- python translate.py -model onmt/tests/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 1 -seed 1 -random_sampling_topk "-1" -random_sampling_temp 0.0001 -tgt data/morph/tgt.valid -out /tmp/trans; diff data/morph/tgt.valid /tmp/trans
# test tool
- PYTHONPATH=$PYTHONPATH:. python tools/extract_embeddings.py -model onmt/tests/test_model.pt

Expand Down
2 changes: 2 additions & 0 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,8 @@ def translate_opts(parser):
default=1., type=float,
help="""If doing random sampling, divide the logits by
this before computing softmax during decoding.""")
group.add('--seed', '-seed', type=int, default=829,
help="Random seed")

group = parser.add_argument_group('Beam')
group.add('--fast', '-fast', action="store_true",
Expand Down
14 changes: 13 additions & 1 deletion onmt/tests/pull_request_chk.sh
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ ${PYTHON} translate.py -model ${TEST_DIR}/test_model2.pt \
-out /tmp/trans >> ${LOG_FILE} 2>&1
diff ${DATA_DIR}/morph/tgt.valid /tmp/trans
[ "$?" -eq 0 ] || error_exit

${PYTHON} translate.py -model ${TEST_DIR}/test_model2.pt \
-src ${DATA_DIR}/morph/src.valid \
-verbose -batch_size 10 \
-beam_size 1 \
-seed 1 \
-random_sampling_topk=-1 \
-random_sampling_temp=0.0001 \
-tgt ${DATA_DIR}/morph/tgt.valid \
-out /tmp/trans >> ${LOG_FILE} 2>&1
diff ${DATA_DIR}/morph/tgt.valid /tmp/trans
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}


Expand All @@ -186,7 +198,7 @@ ${PYTHON} preprocess.py -train_src /tmp/src-val.txt \
-save_data /tmp/q \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-max_shard_size 1 \
-shard_size 1 \
-dynamic_dict >> ${LOG_FILE} 2>&1
${PYTHON} train.py -data /tmp/q -rnn_size 2 -batch_size 10 \
-word_vec_size 5 -report_every 5 \
Expand Down
2 changes: 1 addition & 1 deletion onmt/tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_method(self):
[('dynamic_dict', True),
('share_vocab', True)],
[('dynamic_dict', True),
('max_shard_size', 500000)],
('shard_size', 500000)],
[('src_vocab', '/tmp/src_vocab.txt'),
('tgt_vocab', '/tmp/tgt_vocab.txt')],
]
Expand Down
15 changes: 2 additions & 13 deletions onmt/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import os
import glob
import random
from itertools import chain

import torch
Expand All @@ -18,6 +17,7 @@
load_fields_from_vocab, old_style_vocab
from onmt.model_builder import build_model
from onmt.utils.optimizers import build_optim
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.utils.logging import init_logger, logger
Expand Down Expand Up @@ -70,20 +70,9 @@ def training_opt_postprocessing(opt, device_id):
logger.info("WARNING: You have a CUDA device, \
should run with -gpu_ranks")

if opt.seed > 0:
torch.manual_seed(opt.seed)
# this one is needed for torchtext random call (shuffled iterator)
# in multi gpu it ensures datasets are read in the same order
random.seed(opt.seed)
# some cudnn methods can be random even after fixing the seed
# unless you tell it to be deterministic
torch.backends.cudnn.deterministic = True

if device_id >= 0:
torch.cuda.set_device(device_id)
if opt.seed > 0:
# These ensure same initialization in multi gpu mode
torch.cuda.manual_seed(opt.seed)
set_random_seed(opt.seed, device_id >= 0)

return opt

Expand Down
3 changes: 3 additions & 0 deletions onmt/translate/translation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import onmt.opts

from onmt.utils.logging import init_logger
from onmt.utils.misc import set_random_seed
from onmt.translate.translator import build_translator


Expand Down Expand Up @@ -207,6 +208,8 @@ def __init__(self, opt, model_id, tokenizer_opt=None, load=False,
self.loading_lock.set()
self.running_lock = threading.Semaphore(value=1)

set_random_seed(opt.seed, opt.cuda)

if load:
self.load()

Expand Down
17 changes: 11 additions & 6 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import onmt.inputters as inputters
import onmt.opts as opts
import onmt.decoders.ensemble
from onmt.utils.misc import set_random_seed


def build_translator(opt, report_score=True, logger=None, out_file=None):
Expand Down Expand Up @@ -128,6 +129,8 @@ def __init__(
"scores": [],
"log_probs": []}

set_random_seed(opt.seed, self.cuda)

def translate(
self,
src,
Expand Down Expand Up @@ -294,12 +297,11 @@ def sample_with_temperature(self, logits, sampling_temp, keep_topk):
if keep_topk > 0:
top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
kth_best = top_values[:, -1].view([-1, 1])
kth_best = kth_best.repeat([1, logits.shape[1]])
kth_best = kth_best.type(torch.cuda.FloatTensor)
kth_best = kth_best.repeat([1, logits.shape[1]]).float()

# Set all logits that are not in the top-k to -1000.
# This puts the probabilities close to 0.
keep = torch.ge(logits, kth_best).type(torch.cuda.FloatTensor)
keep = torch.ge(logits, kth_best).float()
logits = (keep * logits) + ((1-keep) * -10000)

dist = torch.distributions.Multinomial(
Expand All @@ -326,9 +328,12 @@ def _translate_random_sampling(
assert self.block_ngram_repeat == 0

batch_size = batch.batch_size
vocab = self.fields["tgt"].vocab
start_token = vocab.stoi[self.fields["tgt"].init_token]
end_token = vocab.stoi[self.fields["tgt"].eos_token]

tgt_field = self.fields['tgt'][0][1]
vocab = tgt_field.vocab

start_token = vocab.stoi[tgt_field.init_token]
end_token = vocab.stoi[tgt_field.eos_token]

# Encoder forward.
src, enc_states, memory_bank, src_lengths = self._run_encoder(
Expand Down
4 changes: 2 additions & 2 deletions onmt/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Module defining various utilities."""
from onmt.utils.misc import aeq, use_gpu
from onmt.utils.misc import aeq, use_gpu, set_random_seed
from onmt.utils.report_manager import ReportMgr, build_report_manager
from onmt.utils.statistics import Statistics
from onmt.utils.optimizers import build_optim, MultipleOptimizer, \
Optimizer, AdaFactor

__all__ = ["aeq", "use_gpu", "ReportMgr",
__all__ = ["aeq", "use_gpu", "set_random_seed", "ReportMgr",
"build_report_manager", "Statistics",
"build_optim", "MultipleOptimizer", "Optimizer", "AdaFactor"]
17 changes: 17 additions & 0 deletions onmt/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import torch
import random


def aeq(*args):
Expand Down Expand Up @@ -53,3 +54,19 @@ def use_gpu(opt):
"""
return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
(hasattr(opt, 'gpu') and opt.gpu > -1)


def set_random_seed(seed, is_cuda):
    """Seed the Python and PyTorch RNGs for reproducible runs.

    A non-positive ``seed`` disables seeding entirely (no-op).

    Args:
        seed (int): seed value; ignored when <= 0.
        is_cuda (bool): when True, also seed the CUDA RNG.
    """
    if seed <= 0:
        return
    torch.manual_seed(seed)
    # torchtext's shuffled iterators draw from the stdlib RNG; seeding it
    # keeps dataset read order identical across processes in multi-gpu runs.
    random.seed(seed)
    # Some cudnn kernels remain nondeterministic under a fixed seed unless
    # cudnn is explicitly told to pick deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    if is_cuda:
        # Ensures the same parameter initialization in multi-gpu mode.
        torch.cuda.manual_seed(seed)

0 comments on commit 02b3976

Please sign in to comment.