Skip to content

Commit 1ac9300

Browse files
committed
add case_ids
1 parent 543b4df commit 1ac9300

File tree

2 files changed

+36
-25
lines changed

2 files changed

+36
-25
lines changed

run_classification_pl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_epoch_end(self, outputs):
8484

8585
# 나중에 self.label_vocab을 이용해서 실제 태그로 바꾸고 text file에 예측 결과들 덤핑하는것도 짜야함!
8686

87-
self.log("val_acc", test_acc, prog_bar=True)
87+
self.log("test_acc", test_acc, prog_bar=True)
8888
return test_acc
8989

9090
def configure_optimizers(self):

run_ner_pl.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import platform
44
from glob import glob
55

6+
import numpy as np
7+
from seqeval import metrics as seqeval_metrics
8+
69
import torch
710
import pytorch_lightning as pl
811
from pytorch_lightning.callbacks import EarlyStopping
@@ -53,30 +56,31 @@ def validation_step(self, batch, batch_idx):
5356
return result
5457

5558
def validation_epoch_end(self, outputs):
56-
preds = torch.cat([x["preds"] for x in outputs])
57-
labels = torch.cat([x["labels"] for x in outputs])
59+
preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
60+
labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
5861
loss = torch.stack([x["loss"] for x in outputs]).mean()
5962

6063
# remove padding
64+
out_label_list = [[] for _ in range(labels.shape[0])]
65+
preds_list = [[] for _ in range(preds.shape[0])]
66+
assert (len(out_label_list) == len(preds_list)), "Prediction and Label are not matched."
67+
6168
from torch.nn import CrossEntropyLoss
6269
pad_token_label_id = CrossEntropyLoss().ignore_index
6370

64-
final_label_list = []
65-
final_preds_list = []
71+
label_map = {i: label for i, label in enumerate(list(self.label_vocab.keys()))}
72+
6673
for i in range(labels.shape[0]):
6774
for j in range(labels.shape[1]):
6875
if labels[i, j] != pad_token_label_id:
69-
final_label_list.append(labels[i][j])
70-
final_preds_list.append(preds[i][j])
71-
72-
final_label_list = torch.LongTensor(final_label_list)
73-
final_preds_list = torch.LongTensor(final_preds_list)
76+
out_label_list[i].append(label_map[labels[i][j]])
77+
preds_list[i].append(label_map[preds[i][j]])
7478

75-
correct_count = torch.sum(final_label_list == final_preds_list)
76-
val_acc = correct_count.float() / float(len(final_label_list))
79+
# metrics - F1
80+
val_f1 = seqeval_metrics.f1_score(out_label_list, preds_list)
7781

7882
self.log("val_loss", loss, prog_bar=True)
79-
self.log("val_acc", val_acc, prog_bar=True)
83+
self.log("val_f1", val_f1, prog_bar=True)
8084
return loss
8185

8286
def test_step(self, batch, batch_idx):
@@ -91,31 +95,38 @@ def test_step(self, batch, batch_idx):
9195
return result
9296

9397
def test_epoch_end(self, outputs):
94-
preds = torch.cat([x["preds"] for x in outputs])
95-
labels = torch.cat([x["labels"] for x in outputs])
98+
preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy()
99+
labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy()
96100

97101
# remove padding
102+
out_label_list = [[] for _ in range(labels.shape[0])]
103+
preds_list = [[] for _ in range(preds.shape[0])]
104+
assert(len(out_label_list) == len(preds_list)), "Prediction and Label are not matched."
105+
98106
from torch.nn import CrossEntropyLoss
99107
pad_token_label_id = CrossEntropyLoss().ignore_index
100108

101-
final_label_list = []
102-
final_preds_list = []
109+
label_map = {i: label for i, label in enumerate(list(self.label_vocab.keys()))}
110+
103111
for i in range(labels.shape[0]):
104112
for j in range(labels.shape[1]):
105113
if labels[i, j] != pad_token_label_id:
106-
final_label_list.append(labels[i][j])
107-
final_preds_list.append(preds[i][j])
114+
out_label_list[i].append(label_map[labels[i][j]])
115+
preds_list[i].append(label_map[preds[i][j]])
108116

109-
final_label_list = torch.LongTensor(final_label_list)
110-
final_preds_list = torch.LongTensor(final_preds_list)
117+
# metrics - Precision, Recall, F1
118+
result = {
119+
"precision": seqeval_metrics.precision_score(out_label_list, preds_list),
120+
"recall": seqeval_metrics.recall_score(out_label_list, preds_list),
121+
"f1": seqeval_metrics.f1_score(out_label_list, preds_list),
122+
}
111123

112-
correct_count = torch.sum(final_label_list == final_preds_list)
113-
test_acc = correct_count.float() / float(len(final_label_list))
124+
print()
125+
print(seqeval_metrics.classification_report(out_label_list, preds_list))
114126

115127
# 나중에 self.label_vocab을 이용해서 실제 태그로 바꾸고 text file에 예측 결과들 덤핑하는것도 짜야함!
116128

117-
self.log("val_acc", test_acc, prog_bar=True)
118-
return test_acc
129+
return result
119130

120131
def configure_optimizers(self):
121132
from transformers import AdamW

0 commit comments

Comments
 (0)