-
Notifications
You must be signed in to change notification settings - Fork 23
/
test.py
73 lines (63 loc) · 2.49 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# script to develop a toy example
# author: satwik kottur
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import itertools, pdb, random, json
import numpy as np
from chatbots import Team
from dataloader import Dataloader
import sys
sys.path.append('../')
from utilities import saveResultPage
#------------------------------------------------------------------------
# load experiment and model
#------------------------------------------------------------------------
# The checkpoint path is the only (mandatory) command-line argument.
if len(sys.argv) < 2:
    print('Wrong usage:')
    print('python test.py <modelPath>')
    # Exit with a non-zero status so that shells and calling scripts can
    # tell a usage error apart from a successful evaluation run.
    sys.exit(1)

# load and compute on test
loadPath = sys.argv[1]
print('Loading model from: %s' % loadPath)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only point this script at checkpoints from a trusted source.
loaded = torch.load(loadPath)
#------------------------------------------------------------------------
# build dataset, load agents
#------------------------------------------------------------------------
# Parameters stored in the checkpoint under the 'params' key; the same
# dictionary is handed to both the data loader and the agent team.
params = loaded['params']
data = Dataloader(params)
team = Team(params)
# Restore the agents from the loaded checkpoint.
team.loadModel(loaded)
# NOTE(review): presumably switches the agents to evaluation mode before
# testing -- confirm against chatbots.Team.evaluate.
team.evaluate()
#------------------------------------------------------------------------
# test agents
#------------------------------------------------------------------------
import os  # needed only by the chat-log path helper below


def _chatlogPath(modelPath, dtype):
    """Return the JSON chat-log path for checkpoint `modelPath` on split `dtype`.

    The first marker found in the checkpoint's file name ('final', then
    'inter') is replaced by 'chatlog-<dtype>', and the extension is replaced
    by '.json'. Only the file name is rewritten, so directory names that
    happen to contain 'final', 'inter' or 'tar' are left intact. When the
    file name carries neither marker, '-chatlog-<dtype>' is appended to the
    stem instead, so a save path is always defined.
    """
    folder, fileName = os.path.split(modelPath)
    stem, _ = os.path.splitext(fileName)
    tag = 'chatlog-' + dtype
    for marker in ('final', 'inter'):
        if marker in stem:
            stem = stem.replace(marker, tag)
            break
    else:
        # no known marker in the file name: fall back to a suffix
        stem = stem + '-' + tag
    return os.path.join(folder, stem + '.json')


dtypes = ['train', 'test']
for dtype in dtypes:
    # fetch the complete split (train or test) in one go
    images, tasks, labels = data.getCompleteData(dtype)
    # forward pass; the trailing True is handed to team.forward (the
    # original script describes this call as using the greedy policy)
    preds, _, talk = team.forward(Variable(images), Variable(tasks), True)

    # per-instance correctness of the first and second predictions, plus
    # "both right" and "at least one right"
    firstMatch = preds[0].data == labels[:, 0].long()
    secondMatch = preds[1].data == labels[:, 1].long()
    matches = firstMatch & secondMatch
    atleastOne = firstMatch | secondMatch

    # compute accuracy (percentages)
    firstAcc = 100 * torch.mean(firstMatch.float())
    secondAcc = 100 * torch.mean(secondMatch.float())
    atleastAcc = 100 * torch.mean(atleastOne.float())
    accuracy = 100 * torch.mean(matches.float())
    print('\nOverall accuracy [%s]: %.2f (f: %.2f s: %.2f, atleast: %.2f)'
          % (dtype, accuracy, firstAcc, secondAcc, atleastAcc))

    # pretty print
    talk = data.reformatTalk(talk, preds, images, tasks, labels)

    # one chat log per split, stored next to the checkpoint
    savePath = _chatlogPath(loadPath, dtype)
    print('Saving conversations: %s' % savePath)
    with open(savePath, 'w') as fileId:
        json.dump(talk, fileId)
    saveResultPage(savePath)