5
5
6
6
from utils import process_all_files ,load_GloVe ,accuracy_cal
7
7
from model import GA_Reader
8
- from data_loader import DataLoader
8
+ from data_loader import DataLoader , TestLoader
9
9
10
10
def train (epochs ,iterations ,loader_train ,loader_val ,
11
11
model ,optimizer ,loss_function ):
@@ -36,20 +36,44 @@ def train(epochs,iterations,loader_train,loader_val,
36
36
37
37
def validate (loader_val ,model ,loss_function ):
38
38
model .eval ()
39
-
40
- doc ,doc_char ,doc_mask ,query ,query_char ,query_mask , \
41
- char_type ,char_type_mask ,answer ,cloze ,cand , \
42
- cand_mask ,qe_comm = loader_val .__load_next__ ()
43
-
44
- output = model ( doc ,doc_char ,doc_mask ,query ,query_char ,query_mask ,
45
- char_type ,char_type_mask ,answer ,cloze ,cand ,
46
- cand_mask ,qe_comm )
39
+ return_loss = 0
40
+ accuracy = 0
41
+
42
+ for _ in range (loader_val .examples // loader_val .batch_size ):
43
+ doc ,doc_char ,doc_mask ,query ,query_char ,query_mask , \
44
+ char_type ,char_type_mask ,answer ,cloze ,cand , \
45
+ cand_mask ,qe_comm = loader_val .__load_next__ ()
46
+
47
+ output = model ( doc ,doc_char ,doc_mask ,query ,query_char ,query_mask ,
48
+ char_type ,char_type_mask ,answer ,cloze ,cand ,
49
+ cand_mask ,qe_comm )
50
+
51
+ accuracy += accuracy_cal (output ,answer )
52
+ loss = loss_function (output ,answer )
53
+ return_loss += loss .item ()
47
54
48
- accuracy = accuracy_cal ( output , answer )
49
- loss = loss_function ( output , answer )
55
+ return_loss /= ( loader_val . examples // loader_val . batch_size )
56
+ accuracy = 100 * accuracy / loader_val . examples
50
57
51
- return loss . item () ,accuracy
58
+ return return_loss ,accuracy
52
59
60
+ def test (loader_test ,model ):
61
+ model .eval ()
62
+ accuracy = 0
63
+ for _ in range (loader_test .examples // loader_test .batch_size ):
64
+ doc ,doc_char ,doc_mask ,query ,query_char ,query_mask , \
65
+ char_type ,char_type_mask ,answer ,cloze ,cand , \
66
+ cand_mask ,qe_comm = loader_test .__load_next__ ()
67
+
68
+ output = model ( doc ,doc_char ,doc_mask ,query ,query_char ,query_mask ,
69
+ char_type ,char_type_mask ,answer ,cloze ,cand ,
70
+ cand_mask ,qe_comm )
71
+
72
+ accuracy += accuracy_cal (output ,answer )
73
+
74
+ accuracy = 100 * accuracy / loader_test .examples
75
+ print ('test accuracy=' ,accuracy )
76
+
53
77
def main (args ):
54
78
word_to_int ,int_to_word ,char_to_int ,int_to_char , \
55
79
training_data = process_all_files (args .train_file )
@@ -62,11 +86,17 @@ def main(args):
62
86
63
87
optimizer = optim .Adam (model .parameters (),lr = args .lr )
64
88
data_loader_train = DataLoader (training_data [:args .training_size ],args .batch_size )
65
- data_loader_validate = DataLoader (training_data [args .training_size :],args .batch_size )
89
+ data_loader_validate = TestLoader (training_data [args .training_size :args . \
90
+ training_size + args .dev_size ],args .dev_size )
91
+ data_loader_test = TestLoader (training_data [args . \
92
+ training_size_args .dev_size :args . \
93
+ training_size + args .dev_size + args .test_size ],args .test_size )
66
94
67
95
train (args .epochs ,args .iterations ,data_loader_train ,
68
96
data_loader_validate ,model ,optimizer ,loss_function )
69
97
98
+ test (data_loader_test ,model )
99
+
70
100
def setup ():
71
101
parser = argparse .ArgumentParser ('argument parser' )
72
102
parser .add_argument ('--lr' ,type = float ,default = 0.00005 )
@@ -82,9 +112,9 @@ def setup():
82
112
parser .add_argument ('--gru_layers' ,type = int ,default = 3 )
83
113
parser .add_argument ('--embed_file' ,type = str ,default = os .getcwd ()+ '/word2vec_glove.text' )
84
114
parser .add_argument ('--train_file' ,type = str ,default = os .getcwd ()+ '/train/' )
85
- parser .add_argument ('--dev_file ' ,type = str ,default = os . getcwd () + '/validation/' )
86
- parser .add_argument ('--test_file ' ,type = str ,default = os . getcwd () + '/test/' )
87
- parser .add_argument ('--training_size ' ,type = int ,default = 380 , 298 )
115
+ parser .add_argument ('--train_size ' ,type = int ,default = 380298 )
116
+ parser .add_argument ('--dev_size ' ,type = int ,default = 3924 )
117
+ parser .add_argument ('--test_size ' ,type = int ,default = 3198 )
88
118
89
119
args = parser .parse_args ()
90
120
0 commit comments