test.py
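"""Evaluate saved checkpoints for few-shot relation classification on FewRel.

For every checkpoint under model_checkpoint/ whose name contains 'isNPM.tar',
the script runs MAML-style test-time adaptation on validation episodes (one
gradient step on the base model, the remaining steps on a clone) and reports
query-set loss and accuracy, tracking the best-performing checkpoint.
"""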
import time
import argparse
import os
from collections import OrderedDict

import torch
from torch import autograd

from my_transformers.transformers import AdamW
from model import MyModel, MyModel_Clone
from dataloader import get_dataloader
from train import train_one_batch, train_q, zero_grad
def test_model(mymodel, mymodel_clone, args):
    n_way_k_shot = str(args.N) + '-way-' + str(args.K) + '-shot'
    print('Start validating ' + n_way_k_shot)
    cuda = torch.cuda.is_available()
    if cuda:
        mymodel = mymodel.cuda()
        mymodel_clone = mymodel_clone.cuda()
    # Only the validation split is used at test time.
    data_loader = {}
    # data_loader['train'] = get_dataloader(args.train, args.class_name_file, args.N, args.K, args.L, args.noise_rate)
    data_loader['val'] = get_dataloader(args.val, args.class_name_file, args.N, args.K, args.L, args.noise_rate)
    # data_loader['test'] = get_dataloader(args.test, args.class_name_file, args.N, args.K, args.L, args.noise_rate)
    # The meta-optimizer mirrors the training setup; its step() is never called
    # below, so the loaded weights are not updated during testing.
    optim_params = [{'params': mymodel.coder.parameters(), 'lr': 1e-8}]
    optim_params.append({'params': mymodel.fc.parameters(), 'lr': 1e-8})
    optim_params.append({'params': mymodel.mlp.parameters(), 'lr': 1e-8})
    meta_optimizer = AdamW(optim_params, lr=1)
    meta_loss_final = 0.0
    accs = 0.0
    val_iter = args.test_iter
    for it in range(val_iter):
        meta_loss = 0.0
        mymodel.eval()
        class_name, support, support_label, query, query_label = next(data_loader['val'])
        if cuda:
            support_label, query_label = support_label.cuda(), query_label.cuda()
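        # First adaptation step (MAML-style): take one gradient step on the
        # support loss w.r.t. the fc and mlp heads, theta' = theta - task_lr * grad,
        # and copy the resulting fast weights into the clone.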
        loss_s, right_s, query1, class_name1 = train_one_batch(
            args, class_name, support, support_label, query, query_label,
            mymodel, args.task_lr, it)
        zero_grad(mymodel.parameters())
        grads_fc = autograd.grad(loss_s, mymodel.fc.parameters(), retain_graph=True)
        grads_mlp = autograd.grad(loss_s, mymodel.mlp.parameters())
        fast_weights_fc, ordered_params = mymodel.cloned_fc_dict(), OrderedDict()
        fast_weights_mlp = mymodel.cloned_mlp_dict()
        for (key, val), grad in zip(mymodel.fc.named_parameters(), grads_fc):
            fast_weights_fc[key] = ordered_params['fc.' + key] = val - args.task_lr * grad
        for (key, val), grad in zip(mymodel.mlp.named_parameters(), grads_mlp):
            fast_weights_mlp[key] = ordered_params['mlp.' + key] = val - args.task_lr * grad
        name_list = list(mymodel_clone.state_dict())
        with torch.no_grad():  # in-place copy into leaf parameters must not be traced
            for name in ordered_params:
                if name in name_list:
                    mymodel_clone.state_dict()[name].copy_(ordered_params[name])
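        # The remaining adaptation steps run entirely on the clone, so the
        # original model's parameters stay untouched for the next episode.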
        for _ in range(args.Test_update_step - 1):
            # steps 2..Test_update_step (2-10 with the default settings)
            loss_s, right_s, query1, class_name1 = train_one_batch(
                args, class_name, support, support_label, query, query_label,
                mymodel_clone, args.task_lr, it)
            zero_grad(mymodel_clone.parameters())
            grads_fc = autograd.grad(loss_s, mymodel_clone.fc.parameters(), retain_graph=True)
            grads_mlp = autograd.grad(loss_s, mymodel_clone.mlp.parameters())
            fast_weights_fc, ordered_params = mymodel_clone.cloned_fc_dict(), OrderedDict()
            fast_weights_mlp = mymodel_clone.cloned_mlp_dict()
            for (key, val), grad in zip(mymodel_clone.fc.named_parameters(), grads_fc):
                fast_weights_fc[key] = ordered_params['fc.' + key] = val - args.task_lr * grad
            for (key, val), grad in zip(mymodel_clone.mlp.named_parameters(), grads_mlp):
                fast_weights_mlp[key] = ordered_params['mlp.' + key] = val - args.task_lr * grad
            name_list = list(mymodel_clone.state_dict())
            with torch.no_grad():
                for name in ordered_params:
                    if name in name_list:
                        mymodel_clone.state_dict()[name].copy_(ordered_params[name])
        # ----- compute loss and accuracy on the query set -----
        loss_q, right_q = train_q(args, class_name1, query1, query_label, mymodel_clone)
        meta_loss = meta_loss + loss_q
        meta_loss_final += loss_q.item()  # detach so old episode graphs can be freed
        accs += right_q
        meta_optimizer.zero_grad()
        meta_loss.backward()  # frees the graph; no meta step is taken at test time
        if (it + 1) % 100 == 0:
            print('step: {0:4} | test_loss: {1:3.6f}, test_accuracy: {2:3.2f}%'.format(
                it + 1, meta_loss_final / (it + 1), 100 * accs / (it + 1)))
        torch.cuda.empty_cache()
    return accs / val_iter, meta_loss_final / val_iter
def main(args):
    print('----------------------------------------------------')
    print("{}-way-{}-shot Few-Shot Relation Classification".format(args.N, args.K))
    print("Model: {}".format(args.Model))
    print("config:", args)
    print('----------------------------------------------------')
    start_time = time.time()
    mymodel = MyModel(args)
    mymodel_clone = MyModel_Clone(args)
    best_acc = 0.0
    best_loss = 0.0
    best_model_file = None  # guards the final print when no checkpoint matches
    # Evaluate every checkpoint whose name contains 'isNPM.tar' and keep the best.
    for file_name in os.listdir('model_checkpoint'):
        if 'isNPM.tar' in file_name:
            model_file = 'model_checkpoint/' + file_name
            mymodel.load_state_dict(torch.load(model_file))
            acc, loss = test_model(mymodel, mymodel_clone, args)
            print('model_name:', model_file)
            print('[TEST] | loss: {0:2.6f}, accuracy: {1:2.2f}%'.format(loss, acc * 100))
            if acc > best_acc:
                best_acc = acc
                best_loss = loss
                best_model_file = model_file
    print('best_model_name:', best_model_file)
    print('best_loss:', best_loss)
    print('best_acc:', best_acc)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--Model', help='model name', default='Linear')
    parser.add_argument('--train', help='train file', default='data/FewRel1.0/train_wiki.json')
    parser.add_argument('--val', help='val file', default='data/FewRel1.0/val_wiki.json')
    parser.add_argument('--test', help='test file', default='data/FewRel1.0/val_wiki.json')
    parser.add_argument('--class_name_file', help='class name file', default='data/FewRel1.0/pid2name.json')
    parser.add_argument('--seed', type=int, help='random seed', default=15)
    parser.add_argument('--max_length', type=int, help='max sequence length', default=30)
    parser.add_argument('--Train_iter', type=int, help='number of iterations in training', default=100000)
    parser.add_argument('--Val_iter', type=int, help='number of iterations in validation', default=1)
    parser.add_argument('--Test_update_step', type=int, help='number of adaptation steps', default=10)
    parser.add_argument('--B', type=int, help='batch number', default=1)
    parser.add_argument('--N', type=int, help='N way', default=5)
    parser.add_argument('--K', type=int, help='K shot', default=1)
    parser.add_argument('--L', type=int, help='number of queries per class', default=5)
    parser.add_argument('--noise_rate', type=int, help='noise rate, value range 0 to 10', default=0)
    # learning rates and the auxiliary-loss weight are real-valued: parse as float
    parser.add_argument('--task_lr', type=float, help='task (inner-loop) learning rate', default=1e-1)
    parser.add_argument('--meta_lr', type=float, help='meta (outer-loop) learning rate', default=1e-3)
    parser.add_argument('--ITT', type=bool, help='Increasing Training Tasks', default=True)
    parser.add_argument('--NPM_Loss', type=bool, help='auxiliary N-pair-ms loss', default=False)
    parser.add_argument('--lam', type=float, help='the importance of the auxiliary loss', default=0.2)
    parser.add_argument('--SW', type=bool, help='the weights of support instances', default=False)
    # used only by test.py
    parser.add_argument('--test_iter', type=int, help='number of test iterations', default=500)
    args = parser.parse_args()
    main(args)
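# Example invocation (a sketch; assumes '*isNPM.tar' checkpoints exist under
# model_checkpoint/ and the FewRel JSON files sit at the default paths):
#   python test.py --N 5 --K 1 --test_iter 500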