Revamp args (OpenNMT#1287)
* Factor out validation and default opt getting
* Update test_models to use new parser.
* Remove unnecessary cast.
* Test translation server, format its docs.
* Fix Py27 compatibility.
* Actually fix Py27? And start testing TranslationServer.
flauted authored and vince62s committed Feb 15, 2019
1 parent 64a8a9e commit 857e369
Showing 12 changed files with 477 additions and 226 deletions.
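
Note: the onmt.utils.parse.ArgumentParser that this diff imports and calls (via ckpt_model_opts, update_model_opts, and validate_model_opts) lives in one of the changed files not reproduced in this excerpt. As rough orientation only, a wrapper exposing those entry points might look like the sketch below; the class bodies are illustrative assumptions pieced together from the call sites and from the backward-compatibility code this commit removes from model_builder.py, not the actual implementation.

import configargparse

import onmt.opts


class ArgumentParser(configargparse.ArgumentParser):
    """Hypothetical sketch of a parser that owns opt defaults and validation."""

    @classmethod
    def defaults(cls, *args):
        # Collect the default values of every option group passed in *args.
        dummy_parser = cls()
        for opt_group in args:
            opt_group(dummy_parser)
        return dummy_parser.parse_known_args([])[0]

    @classmethod
    def ckpt_model_opts(cls, ckpt_opt):
        # Start from current model-opt defaults, then overwrite with the
        # values stored in the checkpoint, so old checkpoints still pick up
        # options added after they were trained.
        opt = cls.defaults(onmt.opts.model_opts)
        opt.__dict__.update(ckpt_opt.__dict__)
        return opt

    @classmethod
    def update_model_opts(cls, model_opt):
        # Backward compatibility, mirroring the code removed from
        # build_base_model below: a single rnn_size overrides enc/dec sizes.
        if model_opt.rnn_size != -1:
            model_opt.enc_rnn_size = model_opt.rnn_size
            model_opt.dec_rnn_size = model_opt.rnn_size

    @classmethod
    def validate_model_opts(cls, model_opt):
        # Mirrors the assert removed from build_base_model below.
        assert model_opt.model_type in ["text", "img", "audio"], \
            "Unsupported model type %s" % model_opt.model_type
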
6 changes: 2 additions & 4 deletions onmt/decoders/ensemble.py
@@ -116,16 +116,14 @@ def __init__(self, models, raw_probs=False):
         self.models = nn.ModuleList(models)
 
 
-def load_test_model(opt, dummy_opt):
+def load_test_model(opt):
     """ Read in multiple models for ensemble """
     shared_fields = None
     shared_model_opt = None
     models = []
     for model_path in opt.models:
         fields, model, model_opt = \
-            onmt.model_builder.load_test_model(opt,
-                                               dummy_opt,
-                                               model_path=model_path)
+            onmt.model_builder.load_test_model(opt, model_path=model_path)
         if shared_fields is None:
             shared_fields = fields
         else:
2 changes: 1 addition & 1 deletion onmt/encoders/audio_encoder.py
@@ -112,7 +112,7 @@ def forward(self, src, lengths=None):
             t, _, _ = memory_bank.size()
             memory_bank = memory_bank.transpose(0, 2)
             memory_bank = pool(memory_bank)
-            lengths = [int(math.floor((length - stride)/stride + 1))
+            lengths = [int(math.floor((length - stride) / stride + 1))
                        for length in lengths]
             memory_bank = memory_bank.transpose(0, 2)
             src = memory_bank
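
For reference, the expression being reformatted above is the usual pooled-length formula; with illustrative numbers (stride and length are assumed example values, not taken from this diff):

import math

stride, length = 2, 100  # assumed example values
pooled = int(math.floor((length - stride) / stride + 1))
assert pooled == 50
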
24 changes: 8 additions & 16 deletions onmt/model_builder.py
@@ -17,6 +17,7 @@
 from onmt.modules.util_class import Cast
 from onmt.utils.misc import use_gpu
 from onmt.utils.logging import logger
+from onmt.utils.parse import ArgumentParser
 
 
 def build_embeddings(opt, text_field, for_encoder=True):
@@ -77,13 +78,15 @@ def build_decoder(opt, embeddings):
     return str2dec[dec_type].from_opt(opt, embeddings)
 
 
-def load_test_model(opt, dummy_opt, model_path=None):
+def load_test_model(opt, model_path=None):
     if model_path is None:
         model_path = opt.models[0]
     checkpoint = torch.load(model_path,
                             map_location=lambda storage, loc: storage)
 
-    model_opt = checkpoint['opt']
+    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
+    ArgumentParser.update_model_opts(model_opt)
+    ArgumentParser.validate_model_opts(model_opt)
     vocab = checkpoint['vocab']
     if inputters.old_style_vocab(vocab):
         fields = inputters.load_old_vocab(
@@ -92,9 +95,6 @@ def load_test_model(opt, dummy_opt, model_path=None):
     else:
         fields = vocab
 
-    for arg in dummy_opt:
-        if arg not in model_opt:
-            model_opt.__dict__[arg] = dummy_opt[arg]
     model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint,
                              opt.gpu)
     model.eval()
@@ -105,7 +105,9 @@
 def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
     """
     Args:
-        model_opt: the option loaded from checkpoint.
+        model_opt: the option loaded from checkpoint. It's important that
+            the opts have been updated and validated. See
+            :class:`onmt.utils.parse.ArgumentParser`.
         fields: `Field` objects for the model.
         gpu (bool): whether to use gpu.
         checkpoint: the model gnerated by train phase, or a resumed snapshot
@@ -115,14 +117,6 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
         the NMTModel.
     """
 
-    assert model_opt.model_type in ["text", "img", "audio"], \
-        "Unsupported model type %s" % model_opt.model_type
-
-    # for backward compatibility
-    if model_opt.rnn_size != -1:
-        model_opt.enc_rnn_size = model_opt.rnn_size
-        model_opt.dec_rnn_size = model_opt.rnn_size
-
     # Build embeddings.
     if model_opt.model_type == "text":
         src_fields = [f for n, f in fields['src']]
@@ -222,8 +216,6 @@ def fix_key(s):
     model.generator = generator
     model.to(device)
    if model_opt.model_dtype == 'fp16':
-        logger.warning('FP16 is experimental, the generated checkpoints may '
-                       'be incompatible with a future version')
         model.half()
 
     return model
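
The practical effect of the load_test_model change above is that callers no longer need a throwaway parser just to harvest default option values. A before/after sketch of a typical call site follows; the "before" lines paraphrase the old pattern implied by the removed dummy_opt loop and are illustrative, not copied from this diff.

# before (roughly): defaults were harvested with a dummy parser
# dummy_parser = configargparse.ArgumentParser(description='train.py')
# onmt.opts.model_opts(dummy_parser)
# dummy_opt = dummy_parser.parse_known_args([])[0]
# fields, model, model_opt = load_test_model(opt, dummy_opt.__dict__)

# after: defaults, backward compatibility and validation happen inside
# load_test_model via ArgumentParser
fields, model, model_opt = load_test_model(opt)
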
28 changes: 16 additions & 12 deletions onmt/tests/test_models.py
@@ -1,4 +1,3 @@
-import configargparse
 import copy
 import unittest
 import math
@@ -12,8 +11,9 @@
     build_encoder, build_decoder
 from onmt.encoders.image_encoder import ImageEncoder
 from onmt.encoders.audio_encoder import AudioEncoder
+from onmt.utils.parse import ArgumentParser
 
-parser = configargparse.ArgumentParser(description='train.py')
+parser = ArgumentParser(description='train.py')
 onmt.opts.model_opts(parser)
 onmt.opts.train_opts(parser)
 
@@ -127,8 +127,7 @@ def nmtmodel_forward(self, opt, source_l=3, bsize=1):
         embeddings = build_embeddings(opt, word_field)
         enc = build_encoder(opt, embeddings)
 
-        embeddings = build_embeddings(opt, word_field,
-                                      for_encoder=False)
+        embeddings = build_embeddings(opt, word_field, for_encoder=False)
         dec = build_decoder(opt, embeddings)
 
         model = onmt.models.model.NMTModel(enc, dec)
@@ -159,8 +158,7 @@ def imagemodel_forward(self, opt, tgt_l=2, bsize=1, h=15, w=17):
         enc = ImageEncoder(
             opt.enc_layers, opt.brnn, opt.enc_rnn_size, opt.dropout)
 
-        embeddings = build_embeddings(opt, word_field,
-                                      for_encoder=False)
+        embeddings = build_embeddings(opt, word_field, for_encoder=False)
         dec = build_decoder(opt, embeddings)
 
         model = onmt.models.model.NMTModel(enc, dec)
@@ -197,8 +195,7 @@ def audiomodel_forward(self, opt, tgt_l=7, bsize=3, t=37):
                           opt.audio_enc_pooling, opt.dropout,
                           opt.sample_rate, opt.window_size)
 
-        embeddings = build_embeddings(opt, word_field,
-                                      for_encoder=False)
+        embeddings = build_embeddings(opt, word_field, for_encoder=False)
         dec = build_decoder(opt, embeddings)
 
         model = onmt.models.model.NMTModel(enc, dec)
@@ -225,12 +222,11 @@ def _add_test(param_setting, methodname):
     """
 
     def test_method(self):
+        opt = copy.deepcopy(self.opt)
         if param_setting:
-            opt = copy.deepcopy(self.opt)
             for param, setting in param_setting:
                 setattr(opt, param, setting)
-        else:
-            opt = self.opt
+        ArgumentParser.update_model_opts(opt)
         getattr(self, methodname)(opt)
         if param_setting:
             name = 'test_' + methodname + "_" + "_".join(
@@ -305,5 +301,13 @@ def test_method(self):
 for p in tests_nmtmodel:
     p.append(('sample_rate', 5500))
     p.append(('window_size', 0.03))
-    p.append(('audio_enc_pooling', '2'))
+    # when reasonable, set audio_enc_pooling to 2
+    for arg, val in p:
+        if arg == "layers" and int(val) > 2:
+            # Need lengths >= audio_enc_pooling**n_layers.
+            # That condition is unrealistic for large n_layers,
+            # so leave audio_enc_pooling at 1.
+            break
+    else:
+        p.append(('audio_enc_pooling', '2'))
     _add_test(p, 'audiomodel_forward')
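
The new audio_enc_pooling block above relies on Python's for/else: the else clause runs only if the loop finishes without hitting break, so pooling is bumped to 2 only when no deep "layers" setting is present. A minimal standalone illustration of the same idiom (the sample settings are made up):

settings = [('layers', '3'), ('window_size', 0.03)]
for arg, val in settings:
    if arg == "layers" and int(val) > 2:
        break  # a deep stack would need impractically long inputs
else:
    # only reached when the loop did not break
    settings.append(('audio_enc_pooling', '2'))
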