forked from Kyubyong/transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
65 lines (49 loc) · 1.85 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# -*- coding: utf-8 -*-
#/usr/bin/python2
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
Inference
'''
import os
import tensorflow as tf
from data_load import get_batch
from model import Transformer
from hparams import Hparams
from utils import get_hypotheses, calc_bleu, postprocess, load_hparams
import logging
logging.basicConfig(level=logging.INFO)

# Parse the hyperparameters from the command line, then pass them to
# load_hparams() together with the checkpoint path.
# NOTE(review): load_hparams presumably overwrites fields of `hp` with the
# values saved alongside the checkpoint -- confirm in utils.py.
logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
load_hparams(hp, hp.ckpt)

# hp.test1 is passed for both file arguments of get_batch and 100000 for both
# numeric arguments; shuffling is turned off.
# NOTE(review): presumably the two files are source/target and the two numbers
# are maximum lengths, so the reference side is a placeholder here -- confirm
# against data_load.get_batch.
logging.info("# Prepare test batches")
test_batches, num_test_batches, num_test_samples = get_batch(
    hp.test1, hp.test1,
    100000, 100000,
    hp.vocab, hp.test_batch_size,
    shuffle=False)

# Build an iterator from the dataset's types/shapes and an op that binds it to
# the test dataset. `test_iter` (not `iter`) keeps the builtin iter() usable.
test_iter = tf.data.Iterator.from_structure(test_batches.output_types,
                                            test_batches.output_shapes)
xs, ys = test_iter.get_next()
test_init_op = test_iter.make_initializer(test_batches)

# Build the evaluation graph; only the first value returned by eval() is used.
logging.info("# Load model")
m = Transformer(hp)
y_hat, _ = m.eval(xs, ys)
logging.info("# Session")
with tf.Session() as sess:
    # hp.ckpt may name either a checkpoint directory or a single checkpoint
    # file: when latest_checkpoint() returns None, hp.ckpt itself is restored.
    latest_ckpt = tf.train.latest_checkpoint(hp.ckpt)
    ckpt = hp.ckpt if latest_ckpt is None else latest_ckpt
    saver = tf.train.Saver()
    saver.restore(sess, ckpt)

    # Bind the iterator to the test dataset before decoding.
    sess.run(test_init_op)

    logging.info("# get hypotheses")
    hypotheses = get_hypotheses(num_test_batches, num_test_samples, sess, y_hat, m.idx2token)

    logging.info("# write results")
    # The output file is named after the restored checkpoint; basename() also
    # copes with the platform's native path separator.
    model_output = os.path.basename(ckpt)
    if not os.path.exists(hp.testdir):
        os.makedirs(hp.testdir)
    translation = os.path.join(hp.testdir, model_output)
    # One hypothesis per line, with no trailing newline after the last one.
    with open(translation, 'w') as fout:
        fout.write("\n".join(hypotheses))

    logging.info("# calc bleu score and append it to translation")
    calc_bleu(hp.test2, translation)