Skip to content

Commit

Permalink
added metric helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulDanielML committed Oct 4, 2022
1 parent 7140324 commit 2ba714d
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions helpers/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Dict, List, Union
from sklearn.metrics import confusion_matrix
import pandas as pd

from matplotlib import pyplot as plt
import seaborn as sn
import torch
import torchmetrics


def calculate_f1_score(
preds: Union[List, torch.Tensor], target: Union[List, torch.Tensor], average: str = "micro"
):
if type(preds) == list:
preds = torch.as_tensor(preds)
if type(target) == list:
target = torch.as_tensor(target)

number_of_classes = max(max(target).item(), max(preds).item()) + 1
score = torchmetrics.functional.f1_score(
preds=preds, target=target, average=average, num_classes=number_of_classes
).item()
return score


def get_cm_from_predictions(
y_true: List, y_pred: List, encoding: Dict, normalize: str = "true"
) -> pd.DataFrame:
"""
Set normalize to None to get absolute numbers.
"""
reverse_encoding = {v: k for k, v in encoding.items()}

classes = [reverse_encoding[i] for i in set(y_pred).union(set(y_true))]

cf_matrix = confusion_matrix(y_true, y_pred, normalize=normalize)

df = pd.DataFrame(
cf_matrix,
index=classes,
columns=classes,
).round(2)
return df


def plot_and_save_dual_cm(df_A: pd.DataFrame, df_B: pd.DataFrame, save_to_file: str):
"""
Visualizes two confusion matrices next to each other. Input format of df_A and df_B matches
output of 'get_cm_from_predictions'.
"""

fig, axes = plt.subplots(1, 2, figsize=(45, 20), constrained_layout=True)
fig.suptitle("Confusion Matrices A & B", fontsize=30)

sn.heatmap(df_A, annot=True, ax=axes[0])
sn.heatmap(df_B, annot=True, ax=axes[1])
axes[0].set_title("A", fontsize=25)
axes[1].set_title("B", fontsize=25)
axes[0].set_ylabel("True A", fontsize=20)
axes[0].set_xlabel("Predicted A", fontsize=20)
axes[1].set_ylabel("True B", fontsize=20)
axes[1].set_xlabel("Predicted B", fontsize=20)
axes[0].set_aspect("equal")
axes[1].set_aspect("equal")

plt.setp(axes[0].yaxis.get_majorticklabels(), rotation="horizontal")
plt.setp(axes[1].yaxis.get_majorticklabels(), rotation="horizontal")
plt.setp(axes[0].xaxis.get_majorticklabels(), rotation="vertical")
plt.setp(axes[1].xaxis.get_majorticklabels(), rotation="vertical")

fig.savefig(save_to_file, format="png")
return fig

0 comments on commit 2ba714d

Please sign in to comment.