Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Ensemble #732

Merged
merged 16 commits into from
Aug 31, 2018
29 changes: 29 additions & 0 deletions ensemble_translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python
from __future__ import division, unicode_literals
import argparse

from onmt.translate.Translator import make_translator

import onmt.io
import onmt.translate
import onmt
import onmt.ModelConstructor
import onmt.modules
import onmt.opts


def main(opt):
    """Build an ensemble translator from parsed options and run translation.

    With use_ensemble=True, make_translator loads every checkpoint given
    via -model and wraps them in a single ensemble model.
    """
    translator = make_translator(opt, report_score=True, use_ensemble=True)
    translator.translate(
        opt.src_dir, opt.src, opt.tgt, opt.batch_size, opt.attn_debug)


if __name__ == "__main__":
    # Command-line entry point: same options as translate.py, except that
    # -model may be given repeatedly to form an ensemble.
    arg_parser = argparse.ArgumentParser(
        description='ensemble_translate.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(arg_parser)
    onmt.opts.translate_opts(arg_parser, use_ensemble=True)

    main(arg_parser.parse_args())
6 changes: 4 additions & 2 deletions onmt/ModelConstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ def make_decoder(opt, embeddings):
opt.reuse_copy_attn)


def load_test_model(opt, dummy_opt):
checkpoint = torch.load(opt.model,
def load_test_model(opt, dummy_opt, model_path=None):
if model_path is None:
model_path = opt.model
checkpoint = torch.load(model_path,
map_location=lambda storage, loc: storage)
fields = onmt.io.load_fields_from_vocab(
checkpoint['vocab'], data_type=opt.data_type)
Expand Down
56 changes: 52 additions & 4 deletions onmt/modules/Ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import torch
import torch.nn as nn

from onmt.Models import DecoderState, EncoderBase
from onmt.Models import DecoderState, EncoderBase, NMTModel
import onmt.ModelConstructor


class EnsembleDecoderState(DecoderState):
Expand Down Expand Up @@ -43,11 +44,15 @@ def squeeze(self, dim=None):
return EnsembleDecoderOutput([
x.squeeze(dim) for x in self.model_outputs])

def __getitem__(self, index):
    """Return the decoder output produced by the index-th ensemble member."""
    outputs = self.model_outputs
    return outputs[index]


class EnsembleEncoder(EncoderBase):
""" Dummy Encoder that delegates to individual real Encoders """
def __init__(self, model_encoders):
self.model_encoders = tuple(model_encoders)
super(EnsembleEncoder, self).__init__()
self.model_encoders = nn.ModuleList(list(model_encoders))

def forward(self, src, lengths=None, encoder_state=None):
enc_hidden, memory_bank = zip(*[
Expand All @@ -59,7 +64,8 @@ def forward(self, src, lengths=None, encoder_state=None):
class EnsembleDecoder(nn.Module):
""" Dummy Decoder that delegates to individual real Decoders """
def __init__(self, model_decoders):
self.model_decoders = tuple(model_decoders)
super(EnsembleDecoder, self).__init__()
self.model_decoders = nn.ModuleList(list(model_decoders))

def forward(self, tgt, memory_bank, state, memory_lengths=None):
""" See :obj:`RNNDecoderBase.forward()` """
Expand All @@ -70,9 +76,16 @@ def forward(self, tgt, memory_bank, state, memory_lengths=None):
tgt, memory_bank[i], state[i], memory_lengths[i])
for (i, model_decoder)
in enumerate(self.model_decoders)])
mean_attns = self.combine_attns(attns)
return (EnsembleDecoderOutput(outputs),
EnsembleDecoderState(states),
torch.stack(attns).mean(0))
mean_attns)

def combine_attns(self, attns):
    """Average each attention type elementwise across ensemble members.

    Args:
        attns: list of per-model attention dicts; all dicts share the
            same keys and tensor shapes per key.
    Returns:
        dict mapping each key to the mean of the stacked per-model tensors.
    """
    keys = attns[0].keys()
    return {key: torch.stack([attn[key] for attn in attns]).mean(0)
            for key in keys}

def init_decoder_state(self, src, memory_bank, enc_hidden):
""" See :obj:`RNNDecoderBase.init_decoder_state()` """
Expand All @@ -99,3 +112,38 @@ def forward(self, hidden):
for (i, model_generator)
in enumerate(self.model_generators)]
return torch.stack(distributions).mean(0)


class EnsembleModel(NMTModel):
    """ Dummy NMTModel wrapping individual real NMTModels """

    def __init__(self, models):
        # Fan encoding/decoding/generation out to every wrapped model via
        # the dummy Ensemble* modules; keep the models registered so their
        # parameters move with .to()/.cuda() calls on the ensemble.
        encoders = [model.encoder for model in models]
        decoders = [model.decoder for model in models]
        super(EnsembleModel, self).__init__(EnsembleEncoder(encoders),
                                            EnsembleDecoder(decoders))
        self.generator = EnsembleGenerator(
            [model.generator for model in models])
        self.models = nn.ModuleList(models)


def load_test_model(opt, dummy_opt):
    """Read in multiple models for ensemble.

    Loads each checkpoint listed in ``opt.models``, checks that all models
    were trained on the same preprocessed vocabulary, and wraps them in a
    single :obj:`EnsembleModel`.

    Returns:
        (fields, ensemble_model, model_opt) matching the signature of
        ``onmt.ModelConstructor.load_test_model``.
    """
    shared_fields = None
    shared_model_opt = None
    models = []
    for model_path in opt.models:
        fields, model, model_opt = \
            onmt.ModelConstructor.load_test_model(opt,
                                                  dummy_opt,
                                                  model_path=model_path)
        if shared_fields is None:
            shared_fields = fields
        else:
            # Ensembling only makes sense over a shared vocabulary: every
            # model's token->index mapping must match the first model's.
            for key, field in fields.items():
                if field is not None and 'vocab' in field.__dict__:
                    assert field.vocab.stoi == shared_fields[key].vocab.stoi, \
                        'Ensemble models must use the same preprocessed data'
        models.append(model)
        if shared_model_opt is None:
            shared_model_opt = model_opt
        # NOTE: of the per-model options, only `copy_attn` is read after
        # this point. Ensembling copy-attention models with non-copy-attention
        # models is unsupported; exhaustively checking every other possibly
        # incompatible option is deliberately not attempted here.
    ensemble_model = EnsembleModel(models)
    return shared_fields, ensemble_model, shared_model_opt
11 changes: 8 additions & 3 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,15 @@ def train_opts(parser):
help="Window size for spectrogram in seconds.")


def translate_opts(parser):
def translate_opts(parser, use_ensemble=False):
group = parser.add_argument_group('Model')
group.add_argument('-model', required=True,
help='Path to model .pt file')
if use_ensemble:
group.add_argument('-model', dest='models', action='append',
required=True,
help='Path to model .pt file. Use repeatedly.')
else:
group.add_argument('-model', required=True,
help='Path to model .pt file')

group = parser.add_argument_group('Data')
group.add_argument('-data_type', default="text",
Expand Down
16 changes: 12 additions & 4 deletions onmt/translate/Translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import onmt.translate.Beam
import onmt.io
import onmt.opts
import onmt.modules.Ensemble


def make_translator(opt, report_score=True, out_file=None):
def make_translator(opt, report_score=True, out_file=None, use_ensemble=False):
if out_file is None:
out_file = codecs.open(opt.output, 'w', 'utf-8')

Expand All @@ -24,8 +25,12 @@ def make_translator(opt, report_score=True, out_file=None):
onmt.opts.model_opts(dummy_parser)
dummy_opt = dummy_parser.parse_known_args([])[0]

fields, model, model_opt = \
onmt.ModelConstructor.load_test_model(opt, dummy_opt.__dict__)
if use_ensemble:
fields, model, model_opt = \
onmt.modules.Ensemble.load_test_model(opt, dummy_opt.__dict__)
else:
fields, model, model_opt = \
onmt.ModelConstructor.load_test_model(opt, dummy_opt.__dict__)

scorer = onmt.translate.GNMTGlobalScorer(opt.alpha,
opt.beta,
Expand Down Expand Up @@ -282,7 +287,7 @@ def unbottle(m):
src_map = rvar(batch.src_map.data) \
if data_type == 'text' and self.copy_attn else None
if isinstance(memory_bank, tuple):
memory_bank = tuple(rvar(x) for x in memory_bank)
memory_bank = tuple(rvar(x.data) for x in memory_bank)
else:
memory_bank = rvar(memory_bank.data)
memory_lengths = src_lengths.repeat(beam_size)
Expand Down Expand Up @@ -394,6 +399,9 @@ def _run_target(self, batch, data):
return gold_scores

def _report_score(self, name, score_total, words_total):
if words_total == 0:
print("%s No words predicted" % (name,))
return
print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
name, score_total / words_total,
name, math.exp(-score_total / words_total)))
Expand Down