@@ -3,6 +3,9 @@
 import platform
 from glob import glob
 
+import numpy as np
+from seqeval import metrics as seqeval_metrics
+
 import torch
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping
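Note on the new seqeval import: seqeval scores at the entity level and expects each sequence as a list of tag strings (e.g. BIO tags), not flat label-id tensors, which is why the epoch-end hooks below rebuild per-sentence tag lists. A minimal sketch of the API this diff relies on (the tag values are made up for illustration):

```python
from seqeval import metrics as seqeval_metrics

# One inner list per sentence; tags are strings in a scheme such as BIO.
y_true = [["B-PER", "I-PER", "O"], ["B-LOC", "O"]]
y_pred = [["B-PER", "I-PER", "O"], ["O", "O"]]

print(seqeval_metrics.f1_score(y_true, y_pred))  # entity-level F1
print(seqeval_metrics.classification_report(y_true, y_pred))
```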
@@ -53,30 +56,31 @@ def validation_step(self, batch, batch_idx):
         return result
 
     def validation_epoch_end(self, outputs):
-        preds = torch.cat([x["preds"] for x in outputs])
-        labels = torch.cat([x["labels"] for x in outputs])
+        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
+        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
         loss = torch.stack([x["loss"] for x in outputs]).mean()
 
         # remove padding
+        out_label_list = [[] for _ in range(labels.shape[0])]
+        preds_list = [[] for _ in range(preds.shape[0])]
+        assert len(out_label_list) == len(preds_list), "Predictions and labels do not match."
+
         from torch.nn import CrossEntropyLoss
         pad_token_label_id = CrossEntropyLoss().ignore_index
 
-        final_label_list = []
-        final_preds_list = []
+        label_map = {i: label for i, label in enumerate(list(self.label_vocab.keys()))}
+
         for i in range(labels.shape[0]):
             for j in range(labels.shape[1]):
                 if labels[i, j] != pad_token_label_id:
-                    final_label_list.append(labels[i][j])
-                    final_preds_list.append(preds[i][j])
-
-        final_label_list = torch.LongTensor(final_label_list)
-        final_preds_list = torch.LongTensor(final_preds_list)
+                    out_label_list[i].append(label_map[labels[i][j]])
+                    preds_list[i].append(label_map[preds[i][j]])
 
-        correct_count = torch.sum(final_label_list == final_preds_list)
-        val_acc = correct_count.float() / float(len(final_label_list))
+        # metrics - F1
+        val_f1 = seqeval_metrics.f1_score(out_label_list, preds_list)
 
         self.log("val_loss", loss, prog_bar=True)
-        self.log("val_acc", val_acc, prog_bar=True)
+        self.log("val_f1", val_f1, prog_bar=True)
         return loss
 
     def test_step(self, batch, batch_idx):
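The padding-removal loop added in both epoch-end hooks rests on two conventions: CrossEntropyLoss().ignore_index is -100 by default, so positions labeled -100 (padding and, typically, continuation subword tokens) are skipped; and label_map assumes the insertion order of self.label_vocab matches the label ids. A self-contained sketch of the same id-to-tag reconstruction, with a made-up vocab standing in for self.label_vocab:

```python
import numpy as np
from torch.nn import CrossEntropyLoss

pad_token_label_id = CrossEntropyLoss().ignore_index  # -100 by default

# Hypothetical vocab; the module gets this from self.label_vocab.
label_vocab = {"O": 0, "B-PER": 1, "I-PER": 2}
label_map = {i: label for i, label in enumerate(label_vocab.keys())}

labels = np.array([[1, 2, 0, pad_token_label_id]])
preds = np.array([[1, 0, 0, 0]])

out_label_list = [[] for _ in range(labels.shape[0])]
preds_list = [[] for _ in range(preds.shape[0])]
for i in range(labels.shape[0]):
    for j in range(labels.shape[1]):
        if labels[i, j] != pad_token_label_id:  # keep only real token positions
            out_label_list[i].append(label_map[labels[i][j]])
            preds_list[i].append(label_map[preds[i][j]])

print(out_label_list)  # [['B-PER', 'I-PER', 'O']]
print(preds_list)      # [['B-PER', 'O', 'O']]
```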
@@ -91,31 +95,38 @@ def test_step(self, batch, batch_idx):
         return result
 
     def test_epoch_end(self, outputs):
-        preds = torch.cat([x["preds"] for x in outputs])
-        labels = torch.cat([x["labels"] for x in outputs])
+        preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
+        labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
 
         # remove padding
+        out_label_list = [[] for _ in range(labels.shape[0])]
+        preds_list = [[] for _ in range(preds.shape[0])]
+        assert len(out_label_list) == len(preds_list), "Predictions and labels do not match."
+
         from torch.nn import CrossEntropyLoss
         pad_token_label_id = CrossEntropyLoss().ignore_index
 
-        final_label_list = []
-        final_preds_list = []
+        label_map = {i: label for i, label in enumerate(list(self.label_vocab.keys()))}
+
         for i in range(labels.shape[0]):
             for j in range(labels.shape[1]):
                 if labels[i, j] != pad_token_label_id:
-                    final_label_list.append(labels[i][j])
-                    final_preds_list.append(preds[i][j])
+                    out_label_list[i].append(label_map[labels[i][j]])
+                    preds_list[i].append(label_map[preds[i][j]])
 
-        final_label_list = torch.LongTensor(final_label_list)
-        final_preds_list = torch.LongTensor(final_preds_list)
+        # metrics - Precision, Recall, F1
+        result = {
+            "precision": seqeval_metrics.precision_score(out_label_list, preds_list),
+            "recall": seqeval_metrics.recall_score(out_label_list, preds_list),
+            "f1": seqeval_metrics.f1_score(out_label_list, preds_list),
+        }
 
-        correct_count = torch.sum(final_label_list == final_preds_list)
-        test_acc = correct_count.float() / float(len(final_label_list))
+        print()
+        print(seqeval_metrics.classification_report(out_label_list, preds_list))
 
         # TODO: use self.label_vocab to convert ids back to actual tags and dump the prediction results to a text file later!
 
-        self.log("val_acc", test_acc, prog_bar=True)
-        return test_acc
+        return result
 
     def configure_optimizers(self):
         from transformers import AdamW
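The second hunk ends just inside configure_optimizers; for orientation, a typical Lightning pairing of that hook with transformers' AdamW looks roughly like the sketch below. The learning rate and the absence of a scheduler are placeholders, not this repo's actual settings:

```python
def configure_optimizers(self):
    from transformers import AdamW
    # 5e-5 is a common BERT fine-tuning default, used only as a placeholder here.
    return AdamW(self.parameters(), lr=5e-5)
```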