#!/usr/bin/env python
from __future__ import division, unicode_literals
import os
import argparse
import math
import codecs
import torch
from itertools import count
import onmt.io
import onmt.translate
import onmt
import onmt.ModelConstructor
import onmt.modules
import opts
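
# Command-line options are declared in the opts module and parsed once at
# import time, before main() runs.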
parser = argparse.ArgumentParser(
    description='translate.py',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
opts.add_md_help_argument(parser)
opts.translate_opts(parser)
opt = parser.parse_args()
def main():
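    # A dummy 'train.py' parser supplies the default model options that
    # load_test_model expects alongside the translate options.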
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
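
    # opt.gpu < 0 means run on CPU; otherwise pin this process to the
    # requested CUDA device.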
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    # Load the model.
    fields, model, model_opt = \
        onmt.ModelConstructor.load_test_model(opt, dummy_opt.__dict__)

    # File to write sentences to.
    out_file = codecs.open(opt.output, 'w', 'utf-8')

    # Test data
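    # (build_dataset also covers the img / audio data_types; the sample_rate
    # and window_* arguments below are only used for audio features.)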
    data = onmt.io.build_dataset(fields, opt.data_type,
                                 opt.src, opt.tgt,
                                 src_dir=opt.src_dir,
                                 sample_rate=opt.sample_rate,
                                 window_size=opt.window_size,
                                 window_stride=opt.window_stride,
                                 window=opt.window,
                                 use_filter_pred=False)
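
    # No shuffling and no global sort, so translations are written in the
    # same order as the input file.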
    test_data = onmt.io.OrderedIterator(
        dataset=data, device=opt.gpu,
        batch_size=opt.batch_size, train=False, sort=False,
        shuffle=False)

    # Translator
    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha, opt.beta)
    translator = onmt.translate.Translator(model, fields,
                                           beam_size=opt.beam_size,
                                           n_best=opt.n_best,
                                           global_scorer=scorer,
                                           max_length=opt.max_sent_length,
                                           copy_attn=model_opt.copy_attn,
                                           cuda=opt.cuda,
                                           beam_trace=opt.dump_beam != "")
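
    # TranslationBuilder turns the raw output of translate_batch back into
    # token sequences, optionally replacing <unk> tokens when -replace_unk
    # is set.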
    builder = onmt.translate.TranslationBuilder(
        data, translator.fields,
        opt.n_best, opt.replace_unk, opt.tgt)

    # Statistics
    counter = count(1)
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0
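
    # Translate batch by batch, accumulating summed log-likelihoods and word
    # counts for the score / perplexity report below.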
    for batch in test_data:
        batch_data = translator.translate_batch(batch, data)
        translations = builder.from_batch(batch_data)
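
        # One Translation object per source sentence in the batch.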
        for trans in translations:
            pred_score_total += trans.pred_scores[0]
            pred_words_total += len(trans.pred_sents[0])
            if opt.tgt:
                gold_score_total += trans.gold_score
                gold_words_total += len(trans.gold_sent)
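
            # Write the n best hypotheses for this sentence, one per line.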
            n_best_preds = [" ".join(pred)
                            for pred in trans.pred_sents[:opt.n_best]]
            out_file.write('\n'.join(n_best_preds))
            out_file.write('\n')
            out_file.flush()
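
            # With -verbose, also print the per-sentence details produced by
            # trans.log() to stdout.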
            if opt.verbose:
                sent_number = next(counter)
                output = trans.log(sent_number)
                os.write(1, output.encode('utf-8'))
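
    # Average log-likelihood per word and the corresponding perplexity,
    # exp(-score_total / words_total).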
    def report_score(name, score_total, words_total):
        print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
            name, score_total / words_total,
            name, math.exp(-score_total/words_total)))

    report_score('PRED', pred_score_total, pred_words_total)
    if opt.tgt:
        report_score('GOLD', gold_score_total, gold_words_total)
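
    # Optionally dump the accumulated beam-search trace (translator.beam_accum)
    # as JSON for offline inspection.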
    if opt.dump_beam:
        import json
        json.dump(translator.beam_accum,
                  codecs.open(opt.dump_beam, 'w', 'utf-8'))


if __name__ == "__main__":
    main()