9
9
from utils import batch_variable
10
10
from seqeval .metrics import accuracy_score , classification_report , f1_score
11
11
12
+
12
13
def train (model , train_loader , dev_loader , config , vocab ):
13
14
14
- loss_all = np .array ([], dtype = float )
15
- label_all = np .array ([], dtype = float )
16
- predict_all = np .array ([], dtype = float )
17
15
dev_best_f1 = float ('-inf' )
18
-
16
+ avg_loss = []
19
17
optimizer = optim .AdamW (params = model .parameters (), lr = config .lr )
20
18
for epoch in range (0 , config .epochs ):
19
+ train_right , train_total = 0 , 0
21
20
for batch_idx , batch_data in enumerate (train_loader ):
22
21
model .train () #训练模型
23
22
word_ids , label_ids , label_mask = batch_variable (batch_data , vocab , config )
24
- loss , label_predict = model (word_ids , label_ids , label_mask )
23
+ loss , predicts = model (word_ids , label_ids , label_mask )
25
24
26
- loss_all = np .append (loss_all , loss .data .item ())
27
- label_all = np .append (label_all , label_ids .data .cpu ().numpy ())
28
- predict_all = np .append (predict_all , label_predict .data .cpu ().numpy ())
29
- acc = accuracy_score (predict_all , label_all )
25
+ avg_loss .append (loss .data .item ())
26
+
27
+ batch_right = ((predicts == label_ids ) * label_mask ).sum ().item ()
28
+ batch_total = label_mask .sum ().item ()
29
+ train_right += batch_right
30
+ train_total += batch_total
30
31
31
32
optimizer .zero_grad ()
32
33
loss .backward ()
33
34
optimizer .step ()
34
35
36
+
35
37
if batch_idx % 10 == 0 :
36
- print ("Epoch:{}--------Iter:{}--------train_loss:{:.3f}--------train_acc:{:.3f}" .format (epoch + 1 , batch_idx + 1 , loss_all . mean (), acc ))
38
+ print ("Epoch:{}--------Iter:{}--------train_loss:{:.3f}--------train_acc:{:.3f}" .format (epoch + 1 , batch_idx + 1 , np . array ( avg_loss ). mean (), train_right / train_total ))
37
39
dev_loss , dev_acc , dev_f1 , dev_report = evaluate (model , dev_loader , config , vocab )
38
- msg = "Dev Loss:{}--------Dev Acc:{}--------Dev F1:{}"
40
+ msg = "Dev Loss:{:.3f }--------Dev Acc:{:.3f }--------Dev F1:{:.3f }"
39
41
print (msg .format (dev_loss , dev_acc , dev_f1 ))
40
- print ("Dev Report" )
41
42
print (dev_report )
42
43
43
44
if dev_best_f1 < dev_f1 :
44
45
dev_best_f1 = dev_f1
45
46
torch .save (model .state_dict (), config .save_path )
46
47
print ("***************************** Save Model *****************************" )
47
48
48
- def evaluate (config , model , dev_loader , vocab , output_dict = False ):
49
+ def evaluate (model , one_loader , config , vocab , output_dict = False ):
49
50
model .eval () #评价模式
50
- loss_all = np . array ([], dtype = float )
51
+ loss_total = 0
51
52
predict_all = []
52
53
label_all = []
53
54
with torch .no_grad ():
54
- for batch_idx , batch_data in enumerate (dev_loader ):
55
+ for batch_idx , batch_data in enumerate (one_loader ):
55
56
word_ids , label_ids , label_mask = batch_variable (batch_data , vocab , config )
56
- loss , label_predict = model (word_ids , label_ids , label_mask )
57
+ loss , predicts = model (word_ids , label_ids , label_mask )
58
+
59
+ loss_total = loss_total + loss
60
+
61
+ for i , sen_mask in enumerate (label_mask ):
62
+ for j , word_mask in enumerate (sen_mask ):
63
+ if word_mask .item () == False :
64
+ predicts [i ][j ] = 0
65
+ labels_list = []
66
+ for index_i , ids in enumerate (label_ids ):
67
+ labels_list .append ([config .id2label [id .cpu ().item ()] for index_j , id in enumerate (ids )])
68
+ predicts_list = []
69
+ for index_i , pres in enumerate (predicts ):
70
+ predicts_list .append ([config .id2label [pre .cpu ().item ()] for index_j , pre in enumerate (pres )])
71
+
72
+ label_all += labels_list
73
+ predict_all += predicts_list
57
74
58
- loss_all = np .append (loss_all , loss .data .item ())
59
- predict_all .append (label_predict .data )
60
- label_all .append (label_ids .data )
61
75
acc = accuracy_score (label_all , predict_all )
62
- f1 = f1_score (label_all , predict_all , average = 'macro ' )
76
+ f1 = f1_score (label_all , predict_all , average = 'micro ' )
63
77
report = classification_report (label_all , predict_all , digits = 3 , output_dict = output_dict )
64
78
65
- return loss . mean ( ), acc , f1 , report
79
+ return loss_total / len ( one_loader ), acc , f1 , report
0 commit comments