@@ -1,12 +1,11 @@
 from data import *
 from model import *

-encoder = torch.load('seq2seq-encoder.pt')
-decoder = torch.load('seq2seq-decoder.pt')
+MIN_PROB = -0.1

 # # Evaluating the trained model

-def evaluate(sentence, max_length=MAX_LENGTH):
+def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
     input_variable = input_lang.variable_from_sentence(sentence)
     input_length = input_variable.size()[0]
     encoder_hidden = encoder.init_hidden()
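A note on the hunk above: the pickled encoder/decoder are no longer loaded as module-level globals but are passed into evaluate() explicitly, and a MIN_PROB confidence cutoff is introduced (used in evaluate_tests() below). Assuming the prob value returned by evaluate() is a log-probability (for example an accumulated or length-normalised log-softmax score from the decoder, which this diff does not show), MIN_PROB = -0.1 accepts only predictions whose probability exceeds roughly exp(-0.1) ≈ 0.90. A minimal sketch of that conversion, under that assumption:

import math

MIN_PROB = -0.1
# If `prob` is a log-probability, the cutoff corresponds to about 90% model confidence.
print(math.exp(MIN_PROB))  # ~0.9048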
@@ -49,19 +48,23 @@ def evaluate(sentence, max_length=MAX_LENGTH):
     'is my light on'
 ]

-def evaluate_tests():
+def evaluate_tests(encoder, decoder):
     for test_sentence in test_sentences:
-        command, prob, attn = evaluate(test_sentence)
-        if prob > -0.05:
-            print(prob, command)
+        command, prob, attn = evaluate(encoder, decoder, test_sentence)
+        if prob < MIN_PROB:
+            print(test_sentence, '\n\t', prob, '???')
         else:
-            print(prob, "UNKNOWN")
+            print(test_sentence, '\n\t', prob, command)

 if __name__ == '__main__':
     import sys
     input = sys.argv[1]
     print('input', input)
-    command, prob, attn = evaluate(input)
+
+    encoder = torch.load('seq2seq-encoder.pt')
+    decoder = torch.load('seq2seq-decoder.pt')
+
+    command, prob, attn = evaluate(encoder, decoder, input)
     if prob > -0.05:
         print(prob, command)
     else:
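With this refactor nothing is loaded at import time: the __main__ block loads the two pickled models and evaluates the sentence passed as the first command-line argument, while evaluate() and evaluate_tests() can be reused with any encoder/decoder pair. A usage sketch under stated assumptions (the script/module name evaluate.py is not shown in the diff, and the example sentence is taken from the test list):

# Command-line use (script name assumed to be evaluate.py):
#   python evaluate.py "is my light on"

# Programmatic use of the refactored functions:
import torch
from evaluate import evaluate, evaluate_tests  # module name is an assumption

# torch.load() unpickles whole modules here, so the encoder/decoder classes from
# model.py must be importable when this runs (the script's `from model import *` covers that).
encoder = torch.load('seq2seq-encoder.pt')
decoder = torch.load('seq2seq-decoder.pt')

command, prob, attn = evaluate(encoder, decoder, 'is my light on')
print(prob, command)

# Or run all bundled test sentences against the loaded pair:
evaluate_tests(encoder, decoder)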