-
Notifications
You must be signed in to change notification settings - Fork 456
/
test.py
74 lines (57 loc) · 2.3 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""
@author : Hyunwoong
@when : 2019-12-19
@homepage : https://github.com/gusdnd852
"""
import math
from collections import Counter

import numpy as np
import torch

from data import *
from models.model.transformer import Transformer
from util.bleu import get_bleu, idx_to_word
def count_parameters(model):
    """Return the total number of trainable parameters in *model*.

    Only parameters with ``requires_grad=True`` are counted, so frozen
    weights (e.g. a fixed embedding table) do not inflate the figure.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Build the Transformer used for evaluation. All hyperparameters come from
# the project-wide config pulled in via `from data import *`. Dropout is
# set to 0.0 because this script only performs inference on the test set.
model = Transformer(
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    trg_sos_idx=trg_sos_idx,
    d_model=d_model,
    enc_voc_size=enc_voc_size,
    dec_voc_size=dec_voc_size,
    max_len=max_len,
    ffn_hidden=ffn_hidden,
    n_head=n_heads,
    n_layers=n_layers,
    drop_prob=0.00,
    device=device,
).to(device)

print('The model has {:,} trainable parameters'.format(count_parameters(model)))
def test_model(num_examples):
    """Evaluate the saved Transformer checkpoint on the test split.

    Loads ``./saved/model-saved.pt``, runs a teacher-forced forward pass on
    every batch from the module-level ``test_iter``, prints up to
    ``num_examples`` (source, target, prediction) triples per batch, and
    reports the per-batch and corpus-average BLEU scores.

    :param num_examples: maximum number of sentences to decode and score
                         per batch (may exceed the actual batch size)
    """
    iterator = test_iter
    model.load_state_dict(torch.load("./saved/model-saved.pt"))
    model.eval()  # inference mode: disables dropout etc.
    with torch.no_grad():
        batch_bleu = []
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            # Teacher forcing: feed the target shifted by one (drop last token).
            output = model(src, trg[:, :-1])
            total_bleu = []
            for j in range(num_examples):
                # num_examples may exceed the batch size; indexing past the
                # end raises IndexError, which we treat as "no more rows".
                try:
                    src_words = idx_to_word(src[j], loader.source.vocab)
                    trg_words = idx_to_word(trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    print('source :', src_words)
                    print('target :', trg_words)
                    print('predicted :', output_words)
                    print()
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except IndexError:
                    break
            if total_bleu:  # guard: avoid ZeroDivisionError on an empty batch
                total_bleu = sum(total_bleu) / len(total_bleu)
                print('BLEU SCORE = {}'.format(total_bleu))
                batch_bleu.append(total_bleu)
        if batch_bleu:  # guard: iterator may have yielded no scorable batches
            batch_bleu = sum(batch_bleu) / len(batch_bleu)
            print('TOTAL BLEU SCORE = {}'.format(batch_bleu))
if __name__ == '__main__':
    # Decode and score up to one example per row of each test batch;
    # batch_size is a project-wide setting imported via `data`.
    test_model(batch_size)