Skip to content

Commit

Permalink
Refactoring Translator and introducing make_translator
Browse files Browse the repository at this point in the history
  • Loading branch information
pltrdy committed Apr 6, 2018
1 parent e46dee4 commit b8d8072
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 86 deletions.
19 changes: 12 additions & 7 deletions onmt/translate/TranslationServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import codecs
import json
import threading

from onmt.translate.Translator import make_translator

import onmt
import onmt.opts
import onmt.translate
Expand Down Expand Up @@ -209,9 +212,9 @@ def load(self):
timer.start()
self.out_file = io.StringIO()
try:
self.translator = onmt.translate.Translator(self.opt,
report_score=False,
out_file=self.out_file)
self.translator = make_translator(self.opt,
report_score=False,
out_file=self.out_file)
except RuntimeError as e:
raise ServerModelError("Runtime Error: %s" % str(e))

Expand Down Expand Up @@ -279,14 +282,16 @@ def run(self, inputs):
f.write(self.maybe_tokenize(line) + "\n")
timer.tick(name="writing")
try:
self.translator.translate(None, src_path, None)
self.translator.translate(None, src_path, None,
self.opt.batch_size)
except RuntimeError as e:
raise ServerModelError("Runtime Error: %s" % str(e))

timer.tick(name="translation")
print("Using model #%d\t%d inputs (%d subsegment)\ttranslation time: %f" %
(self.model_id, len(subsegment), sscount,
timer.times['translation']))
print("""Using model #%d\t%d inputs (%d subsegment)
\ttranslation time: %f""" % (self.model_id, len(subsegment),
sscount,
timer.times['translation']))
self.reset_unload_timer()
results = self.out_file.getvalue().split("\n")
results = ['\n'.join([self.maybe_detokenize(_)
Expand Down
179 changes: 116 additions & 63 deletions onmt/translate/Translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,37 @@
import onmt.opts


def make_translator(opt, report_score=True, out_file=None):
    """Build a ``Translator`` from parsed command-line options.

    Loads the checkpoint referenced by ``opt``, pins the requested GPU,
    constructs the GNMT beam scorer, and forwards the translation options
    to the ``Translator`` constructor.

    Args:
        opt: parsed translate options (argparse ``Namespace``).
        report_score (bool): forwarded to ``Translator``.
        out_file: writable stream for predictions; when ``None`` a UTF-8
            file named ``opt.output`` is opened.

    Returns:
        Translator: a ready-to-use translator instance.
    """
    if out_file is None:
        # Default destination: the output path given on the command line.
        out_file = codecs.open(opt.output, 'w', 'utf-8')

    if opt.gpu > -1:
        torch.cuda.set_device(opt.gpu)

    # Parse an empty argv with the model options so that load_test_model
    # receives every model option populated with its default value.
    dummy_parser = argparse.ArgumentParser(description='train.py')
    onmt.opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    fields, model, model_opt = \
        onmt.ModelConstructor.load_test_model(opt, dummy_opt.__dict__)

    scorer = onmt.translate.GNMTGlobalScorer(
        opt.alpha, opt.beta, opt.coverage_penalty, opt.length_penalty)

    # Options forwarded verbatim from `opt` to the Translator constructor.
    forwarded_names = ("beam_size", "n_best", "max_length", "min_length",
                       "stepwise_penalty", "block_ngram_repeat",
                       "ignore_when_blocking", "dump_beam",
                       "data_type", "replace_unk", "gpu", "verbose")
    forwarded = {name: getattr(opt, name) for name in forwarded_names}

    return Translator(model, fields, global_scorer=scorer,
                      out_file=out_file, report_score=report_score,
                      copy_attn=model_opt.copy_attn, **forwarded)


class Translator(object):
"""
Uses a model to translate a batch of sentences.
Expand All @@ -31,45 +62,63 @@ class Translator(object):
cuda (bool): use cuda
beam_trace (bool): trace beam search for debugging
"""
def __init__(self, opt, report_score=True, out_file=None):
self.opt = opt
self.report_score = report_score

dummy_parser = argparse.ArgumentParser(description='train.py')
onmt.opts.model_opts(dummy_parser)
dummy_opt = dummy_parser.parse_known_args([])[0]

self.opt.cuda = opt.gpu > -1
if self.opt.cuda:
torch.cuda.set_device(self.opt.gpu)

fields, model, model_opt = \
onmt.ModelConstructor.load_test_model(opt, dummy_opt.__dict__)

# File to write sentences to.
self.out_file = out_file
if out_file is None:
self.out_file = codecs.open(opt.output, 'w', 'utf-8')

# Scorer
self.global_scorer = onmt.translate.GNMTGlobalScorer(
self.opt.alpha, self.opt.beta, self.opt.coverage_penalty,
self.opt.length_penalty)
def __init__(self,
model,
fields,
beam_size,
n_best=1,
max_length=100,
global_scorer=None,
copy_attn=False,
gpu=False,
dump_beam="",
min_length=0,
stepwise_penalty=False,
block_ngram_repeat=0,
ignore_when_blocking=[],
sample_rate='16000',
window_size=.02,
window_stride=.01,
window='hamming',
use_filter_pred=False,
data_type="text",
replace_unk=False,
report_score=True,
report_bleu=False,
report_rouge=False,
verbose=False,
out_file=None):
self.gpu = gpu
self.cuda = gpu > -1

self.model = model
self.fields = fields
self.n_best = opt.n_best
self.max_length = opt.max_length
self.copy_attn = model_opt.copy_attn
self.beam_size = opt.beam_size
self.cuda = opt.cuda
self.min_length = opt.min_length
self.stepwise_penalty = opt.stepwise_penalty
self.beam_trace = opt.dump_beam != ""
self.block_ngram_repeat = opt.block_ngram_repeat
self.ignore_when_blocking = set(opt.ignore_when_blocking)
self.n_best = n_best
self.max_length = max_length
self.global_scorer = global_scorer
self.copy_attn = copy_attn
self.beam_size = beam_size
self.min_length = min_length
self.stepwise_penalty = stepwise_penalty
self.dump_beam = dump_beam
self.block_ngram_repeat = block_ngram_repeat
self.ignore_when_blocking = set(ignore_when_blocking)
self.sample_rate = sample_rate
self.window_size = window_size
self.window_stride = window_stride
self.window = window
self.use_filter_pred = use_filter_pred
self.replace_unk = replace_unk
self.data_type = data_type
self.verbose = verbose
self.out_file = out_file
self.report_score = report_score
self.report_bleu = report_bleu
self.report_rouge = report_rouge

# for debugging
self.beam_trace = self.dump_beam != ""
self.beam_accum = None
if self.beam_trace:
self.beam_accum = {
Expand All @@ -78,26 +127,26 @@ def __init__(self, opt, report_score=True, out_file=None):
"scores": [],
"log_probs": []}

def translate(self, src_dir, src_path, tgt_path):
def translate(self, src_dir, src_path, tgt_path, batch_size):
data = onmt.io.build_dataset(self.fields,
self.opt.data_type,
self.data_type,
src_path,
tgt_path,
src_dir=src_dir,
sample_rate=self.opt.sample_rate,
window_size=self.opt.window_size,
window_stride=self.opt.window_stride,
window=self.opt.window,
use_filter_pred=False)
sample_rate=self.sample_rate,
window_size=self.window_size,
window_stride=self.window_stride,
window=self.window,
use_filter_pred=self.use_filter_pred)

data_iter = onmt.io.OrderedIterator(
dataset=data, device=self.opt.gpu,
batch_size=self.opt.batch_size, train=False, sort=False,
dataset=data, device=self.gpu,
batch_size=batch_size, train=False, sort=False,
sort_within_batch=True, shuffle=False)

builder = onmt.translate.TranslationBuilder(
data, self.fields,
self.opt.n_best, self.opt.replace_unk, self.opt.tgt)
self.n_best, self.replace_unk, tgt_path)

# Statistics
counter = count(1)
Expand All @@ -111,16 +160,17 @@ def translate(self, src_dir, src_path, tgt_path):
for trans in translations:
pred_score_total += trans.pred_scores[0]
pred_words_total += len(trans.pred_sents[0])
if self.opt.tgt:
if tgt_path is not None:
gold_score_total += trans.gold_score
gold_words_total += len(trans.gold_sent) + 1

n_best_preds = [" ".join(pred)
for pred in trans.pred_sents[:self.opt.n_best]]
self.out_file.write('\n'.join(n_best_preds) + '\n')
for pred in trans.pred_sents[:self.n_best]]
self.out_file.write('\n'.join(n_best_preds))
self.out_file.write('\n')
self.out_file.flush()

if self.opt.verbose:
if self.verbose:
sent_number = next(counter)
output = trans.log(sent_number)
os.write(1, output.encode('utf-8'))
Expand All @@ -129,15 +179,15 @@ def translate(self, src_dir, src_path, tgt_path):
self._report_score('PRED', pred_score_total, pred_words_total)
if tgt_path is not None:
self._report_score('GOLD', gold_score_total, gold_words_total)
if self.opt.report_bleu:
self._report_bleu()
if self.opt.report_rouge:
self._report_rouge()
if self.report_bleu:
self._report_bleu(tgt_path)
if self.report_rouge:
self._report_rouge(tgt_path)

if self.opt.dump_beam:
if self.dump_beam:
import json
json.dump(self.translator.beam_accum,
codecs.open(self.opt.dump_beam, 'w', 'utf-8'))
codecs.open(self.dump_beam, 'w', 'utf-8'))

def translate_batch(self, batch, data):
"""
Expand Down Expand Up @@ -197,7 +247,7 @@ def unbottle(m):

enc_states, memory_bank = self.model.encoder(src, src_lengths)
dec_states = self.model.decoder.init_decoder_state(
src, memory_bank, enc_states)
src, memory_bank, enc_states)

if src_lengths is None:
src_lengths = torch.Tensor(batch_size).type_as(memory_bank.data)\
Expand Down Expand Up @@ -318,24 +368,27 @@ def _run_target(self, batch, data):

def _report_score(self, name, score_total, words_total):
print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
name, score_total / words_total,
name, math.exp(-score_total / words_total)))
name, score_total / words_total,
name, math.exp(-score_total / words_total)))

def _report_bleu(self):
def _report_bleu(self, tgt_path):
import subprocess
path = os.path.split(os.path.realpath(__file__))[0]
print()

res = subprocess.check_output("perl %s/tools/multi-bleu.perl %s < %s"
% (path, self.opt.tgt, self.opt.output),
res = subprocess.check_output("perl %s/tools/multi-bleu.perl %s"
% (path, tgt_path, self.output),
stdin=self.out_file,
shell=True).decode("utf-8")

print(">> " + res.strip())

def _report_rouge(self):
def _report_rouge(self, tgt_path):
import subprocess
path = os.path.split(os.path.realpath(__file__))[0]
res = subprocess.check_output(
"python %s/tools/test_rouge.py -r %s -c %s"
% (path, self.opt.tgt, self.opt.output),
shell=True).decode("utf-8")
"python %s/tools/test_rouge.py -r %s -c STDIN"
% (path, tgt_path),
shell=True,
stdin=self.out_file).decode("utf-8")
print(res.strip())
32 changes: 19 additions & 13 deletions tools/test_rouge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
import time
import pyrouge
import shutil
import sys


def test_rouge(cand_file, ref_file):
f_cand = open(cand_file, encoding="utf-8")
f_ref = open(ref_file, encoding="utf-8")
def test_rouge(cand, ref):
"""Calculate ROUGE scores of sequences passed as an iterator
e.g. a list of str, an open file, StringIO or even sys.stdin
"""
current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
tmp_dir = ".rouge-tmp-{}".format(current_time)
try:
if not os.path.isdir(tmp_dir):
os.mkdir(tmp_dir)
os.mkdir(tmp_dir + "/candidate")
os.mkdir(tmp_dir + "/reference")
candidates = [line.strip() for line in f_cand]
references = [line.strip() for line in f_ref]
candidates = [line.strip() for line in cand]
references = [line.strip() for line in ref]
assert len(candidates) == len(references)
cnt = len(candidates)
for i in range(cnt):
Expand All @@ -29,8 +31,6 @@ def test_rouge(cand_file, ref_file):
with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w",
encoding="utf-8") as f:
f.write(references[i])
f_cand.close()
f_ref.close()
r = pyrouge.Rouge155()
r.model_dir = tmp_dir + "/reference/"
r.system_dir = tmp_dir + "/candidate/"
Expand All @@ -45,17 +45,23 @@ def test_rouge(cand_file, ref_file):
shutil.rmtree(tmp_dir)


def rouge_results_to_str(results_dict):
    """Format pyrouge F-scores into a one-line ``>> ROUGE(...)`` summary.

    Each of the five pyrouge F-score entries (ROUGE-1/2/3/L/SU4) is
    multiplied by 100 and rendered with two decimal places, separated
    by slashes.

    Args:
        results_dict (dict): pyrouge results with keys
            ``rouge_{1,2,3,l,su*}_f_score``.

    Returns:
        str: e.g. ``">> ROUGE(1/2/3/L/SU4): 41.23/18.50/..."``.
    """
    score_keys = ("rouge_1_f_score", "rouge_2_f_score", "rouge_3_f_score",
                  "rouge_l_f_score", "rouge_su*_f_score")
    rendered = "/".join("{:.2f}".format(results_dict[key] * 100)
                        for key in score_keys)
    return ">> ROUGE(1/2/3/L/SU4): " + rendered


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-c', type=str, default="candidate.txt",
help='candidate file')
parser.add_argument('-r', type=str, default="reference.txt",
help='reference file')
args = parser.parse_args()
if args.c.upper() == "STDIN":
args.c = sys.stdin
results_dict = test_rouge(args.c, args.r)
print(">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}".format(
results_dict["rouge_1_f_score"] * 100,
results_dict["rouge_2_f_score"] * 100,
results_dict["rouge_3_f_score"] * 100,
results_dict["rouge_l_f_score"] * 100,
results_dict["rouge_su*_f_score"] * 100))
print(rouge_results_to_str(results_dict))
7 changes: 4 additions & 3 deletions translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import division, unicode_literals
import argparse

from onmt.translate.Translator import make_translator

import onmt.io
import onmt.translate
import onmt
Expand All @@ -11,9 +13,8 @@


def main(opt):
translator = onmt.translate.Translator(opt,
report_score=True)
translator.translate(opt.src_dir, opt.src, opt.tgt)
translator = make_translator(opt, report_score=True)
translator.translate(opt.src_dir, opt.src, opt.tgt, opt.batch_size)


if __name__ == "__main__":
Expand Down

0 comments on commit b8d8072

Please sign in to comment.