-
Notifications
You must be signed in to change notification settings - Fork 1k
/
model.py
112 lines (98 loc) · 4.62 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import wandb
import hydra
import numpy as np
import pandas as pd
import torchmetrics
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification
from omegaconf import OmegaConf, DictConfig
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
class ColaModel(pl.LightningModule):
    """Two-class sequence classifier wrapping a pretrained transformer.

    The wrapped model is loaded through
    ``AutoModelForSequenceClassification`` with two output labels. Training
    logs loss and accuracy; validation additionally logs macro/micro
    precision and recall, F1, and a confusion matrix at the end of every
    validation epoch.

    Batches are expected to be mappings with the keys ``"input_ids"``,
    ``"attention_mask"`` and ``"label"``.
    """

    def __init__(
        self,
        model_name: str = "google/bert_uncased_L-2_H-128_A-2",
        lr: float = 3e-5,
    ):
        """Build the classifier and its metric objects.

        Args:
            model_name: Identifier handed to ``from_pretrained``.
            lr: Learning rate used by the optimizer in
                ``configure_optimizers`` (read back from ``self.hparams``).
        """
        super().__init__()
        # Stores the constructor arguments so ``self.hparams["lr"]`` is
        # available later in ``configure_optimizers``.
        self.save_hyperparameters()

        # Single source of truth for the label count used by both the
        # classification head and the metrics below.
        self.num_classes = 2
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=self.num_classes
        )

        # Separate accuracy objects for train and validation so their
        # internal state is never mixed.
        self.train_accuracy_metric = torchmetrics.Accuracy()
        self.val_accuracy_metric = torchmetrics.Accuracy()
        self.f1_metric = torchmetrics.F1(num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(average="micro")
        self.recall_micro_metric = torchmetrics.Recall(average="micro")

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return its output object unchanged.

        Args:
            input_ids: Token ids forwarded to the model as ``input_ids``.
            attention_mask: Mask forwarded to the model as ``attention_mask``.
            labels: Optional targets forwarded as ``labels``; the training
                and validation steps pass them and read ``.loss`` from the
                result.

        Returns:
            The model output; callers in this class use its ``.logits`` and
            ``.loss`` attributes.
        """
        return self.bert(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy for one batch.

        Returns:
            The loss taken from the model output.
        """
        labels = batch["label"]
        # Call the module itself (not ``forward``) so nn.Module hooks run.
        # Passing labels lets the model produce the loss directly.
        outputs = self(batch["input_ids"], batch["attention_mask"], labels=labels)
        preds = torch.argmax(outputs.logits, 1)
        train_acc = self.train_accuracy_metric(preds, labels)

        self.log("train/loss", outputs.loss, prog_bar=True, on_epoch=True)
        self.log("train/acc", train_acc, prog_bar=True, on_epoch=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log its metrics.

        Returns:
            A dict with the batch ``"labels"`` and ``"logits"``, which
            ``validation_epoch_end`` concatenates across batches.
        """
        labels = batch["label"]
        outputs = self(batch["input_ids"], batch["attention_mask"], labels=labels)
        preds = torch.argmax(outputs.logits, 1)

        # Metrics
        valid_acc = self.val_accuracy_metric(preds, labels)
        precision_macro = self.precision_macro_metric(preds, labels)
        recall_macro = self.recall_macro_metric(preds, labels)
        precision_micro = self.precision_micro_metric(preds, labels)
        recall_micro = self.recall_micro_metric(preds, labels)
        f1 = self.f1_metric(preds, labels)

        # Logging metrics
        self.log("valid/loss", outputs.loss, prog_bar=True, on_step=True)
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True)
        self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True)
        self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True)
        self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True)
        self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True)
        self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
        return {"labels": labels, "logits": outputs.logits}

    def validation_epoch_end(self, outputs):
        """Log a confusion matrix over all validation batches of the epoch.

        Args:
            outputs: List of the dicts returned by ``validation_step``.
        """
        # torch.cat raises on an empty list; nothing to plot in that case.
        if not outputs:
            return

        labels = torch.cat([x["labels"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])

        # Tensor.numpy() only works for CPU tensors that do not require grad,
        # so detach and move to host memory first (needed when running on GPU).
        labels_np = labels.detach().cpu().numpy()
        logits_np = logits.detach().cpu().numpy()

        # Confusion matrix via the built-in W&B plot.
        # NOTE(review): assumes the attached logger is a W&B logger whose
        # ``experiment`` exposes ``log`` - confirm against the trainer setup.
        # Alternatives that were tried here: wandb.sklearn.plot_confusion_matrix,
        # a seaborn heatmap logged as wandb.Image, and wandb.plot.roc_curve.
        self.logger.experiment.log(
            {
                "conf": wandb.plot.confusion_matrix(
                    probs=logits_np, y_true=labels_np
                )
            }
        )

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the saved ``lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])