Skip to content

Commit e22186f

Browse files
committed
pass encoder and decoder as args
1 parent fe3533a commit e22186f

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

evaluate.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from data import *
22
from model import *
33

4-
encoder = torch.load('seq2seq-encoder.pt')
5-
decoder = torch.load('seq2seq-decoder.pt')
4+
MIN_PROB = -0.1
65

76
# # Evaluating the trained model
87

9-
def evaluate(sentence, max_length=MAX_LENGTH):
8+
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
109
input_variable = input_lang.variable_from_sentence(sentence)
1110
input_length = input_variable.size()[0]
1211
encoder_hidden = encoder.init_hidden()
@@ -49,19 +48,23 @@ def evaluate(sentence, max_length=MAX_LENGTH):
4948
'is my light on'
5049
]
5150

52-
def evaluate_tests():
51+
def evaluate_tests(encoder, decoder, ):
5352
for test_sentence in test_sentences:
54-
command, prob, attn = evaluate(test_sentence)
55-
if prob > -0.05:
56-
print(prob, command)
53+
command, prob, attn = evaluate(encoder, decoder, test_sentence)
54+
if prob < MIN_PROB:
55+
print(test_sentence, '\n\t', prob, '???')
5756
else:
58-
print(prob, "UNKNOWN")
57+
print(test_sentence, '\n\t', prob, command)
5958

6059
if __name__ == '__main__':
6160
import sys
6261
input = sys.argv[1]
6362
print('input', input)
64-
command, prob, attn = evaluate(input)
63+
64+
encoder = torch.load('seq2seq-encoder.pt')
65+
decoder = torch.load('seq2seq-decoder.pt')
66+
67+
command, prob, attn = evaluate(encoder, decoder, input)
6568
if prob > -0.05:
6669
print(prob, command)
6770
else:

0 commit comments

Comments
 (0)