Skip to content

Commit acd0e7a

Browse files
added test
1 parent c16d15c commit acd0e7a

File tree

1 file changed

+46
-16
lines changed

1 file changed

+46
-16
lines changed

train.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from utils import process_all_files,load_GloVe,accuracy_cal
77
from model import GA_Reader
8-
from data_loader import DataLoader
8+
from data_loader import DataLoader,TestLoader
99

1010
def train(epochs,iterations,loader_train,loader_val,
1111
model,optimizer,loss_function):
@@ -36,20 +36,44 @@ def train(epochs,iterations,loader_train,loader_val,
3636

3737
def validate(loader_val,model,loss_function):
3838
model.eval()
39-
40-
doc,doc_char,doc_mask,query,query_char,query_mask, \
41-
char_type,char_type_mask,answer,cloze,cand, \
42-
cand_mask,qe_comm=loader_val.__load_next__()
43-
44-
output=model( doc,doc_char,doc_mask,query,query_char,query_mask,
45-
char_type,char_type_mask,answer,cloze,cand,
46-
cand_mask,qe_comm)
39+
return_loss=0
40+
accuracy=0
41+
42+
for _ in range(loader_val.examples//loader_val.batch_size):
43+
doc,doc_char,doc_mask,query,query_char,query_mask, \
44+
char_type,char_type_mask,answer,cloze,cand, \
45+
cand_mask,qe_comm=loader_val.__load_next__()
46+
47+
output=model( doc,doc_char,doc_mask,query,query_char,query_mask,
48+
char_type,char_type_mask,answer,cloze,cand,
49+
cand_mask,qe_comm)
50+
51+
accuracy+=accuracy_cal(output,answer)
52+
loss=loss_function(output,answer)
53+
return_loss+=loss.item()
4754

48-
accuracy=accuracy_cal(output,answer)
49-
loss=loss_function(output,answer)
55+
return_loss/=(loader_val.examples//loader_val.batch_size)
56+
accuracy=100*accuracy/loader_val.examples
5057

51-
return loss.item(),accuracy
58+
return return_loss,accuracy
5259

60+
def test(loader_test,model):
61+
model.eval()
62+
accuracy=0
63+
for _ in range(loader_test.examples//loader_test.batch_size):
64+
doc,doc_char,doc_mask,query,query_char,query_mask, \
65+
char_type,char_type_mask,answer,cloze,cand, \
66+
cand_mask,qe_comm=loader_test.__load_next__()
67+
68+
output=model( doc,doc_char,doc_mask,query,query_char,query_mask,
69+
char_type,char_type_mask,answer,cloze,cand,
70+
cand_mask,qe_comm)
71+
72+
accuracy+=accuracy_cal(output,answer)
73+
74+
accuracy=100*accuracy/loader_test.examples
75+
print('test accuracy=',accuracy)
76+
5377
def main(args):
5478
word_to_int,int_to_word,char_to_int,int_to_char, \
5579
training_data=process_all_files(args.train_file)
@@ -62,11 +86,17 @@ def main(args):
6286

6387
optimizer=optim.Adam(model.parameters(),lr=args.lr)
6488
data_loader_train=DataLoader(training_data[:args.training_size],args.batch_size)
65-
data_loader_validate=DataLoader(training_data[args.training_size:],args.batch_size)
89+
data_loader_validate=TestLoader(training_data[args.training_size:args. \
90+
training_size+args.dev_size],args.dev_size)
91+
data_loader_test=TestLoader(training_data[args. \
92+
training_size_args.dev_size:args. \
93+
training_size+args.dev_size+args.test_size],args.test_size)
6694

6795
train(args.epochs,args.iterations,data_loader_train,
6896
data_loader_validate,model,optimizer,loss_function)
6997

98+
test(data_loader_test,model)
99+
70100
def setup():
71101
parser=argparse.ArgumentParser('argument parser')
72102
parser.add_argument('--lr',type=float,default=0.00005)
@@ -82,9 +112,9 @@ def setup():
82112
parser.add_argument('--gru_layers',type=int,default=3)
83113
parser.add_argument('--embed_file',type=str,default=os.getcwd()+'/word2vec_glove.text')
84114
parser.add_argument('--train_file',type=str,default=os.getcwd()+'/train/')
85-
parser.add_argument('--dev_file',type=str,default=os.getcwd()+'/validation/')
86-
parser.add_argument('--test_file',type=str,default=os.getcwd()+'/test/')
87-
parser.add_argument('--training_size',type=int,default=380,298)
115+
parser.add_argument('--train_size',type=int,default=380298)
116+
parser.add_argument('--dev_size',type=int,default=3924)
117+
parser.add_argument('--test_size',type=int,default=3198)
88118

89119
args=parser.parse_args()
90120

0 commit comments

Comments
 (0)