Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add visualization tools #50

Merged
merged 11 commits into from
Oct 15, 2023
Prev Previous commit
Next Next commit
✨ Plot ood criteria
  • Loading branch information
o-laurent committed Oct 14, 2023
commit 70babd1daa8a1d99a8dbf425ad6bb91380e7a806
89 changes: 76 additions & 13 deletions torch_uncertainty/routines/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytorch_lightning.utilities.memory import get_model_size_mb
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from timm.data import Mixup
from torch import nn
from torch import Tensor, nn
from torchmetrics import Accuracy, CalibrationError, MetricCollection
from torchmetrics.classification import (
BinaryAccuracy,
Expand All @@ -31,7 +31,7 @@
NegativeLogLikelihood,
VariationRatio,
)
from ..visualization import CalibrationPlot
from ..visualization import CalibrationPlot, plot_hist


# fmt:on
Expand Down Expand Up @@ -165,7 +165,7 @@ def criterion(self) -> nn.Module:
self.loss = partial(self.loss, model=self.model)
return self.loss()

def forward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: Tensor) -> Tensor:
return self.model.forward(input)

def on_train_start(self) -> None:
Expand All @@ -191,7 +191,7 @@ def on_train_start(self) -> None:
)

def training_step(
self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
self, batch: Tuple[Tensor, Tensor], batch_idx: int
) -> STEP_OUTPUT:
batch = self.mixup(*batch)
inputs, targets = self.format_batch_fn(batch)
Expand All @@ -210,7 +210,7 @@ def training_step(
return loss

def validation_step(
self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
self, batch: Tuple[Tensor, Tensor], batch_idx: int
) -> None:
inputs, targets = batch
logits = self.forward(inputs)
Expand All @@ -230,10 +230,10 @@ def validation_epoch_end(

def test_step(
self,
batch: Tuple[torch.Tensor, torch.Tensor],
batch: Tuple[Tensor, Tensor],
batch_idx: int,
dataloader_idx: Optional[int] = 0,
) -> None:
) -> Tensor:
inputs, targets = batch
logits = self.forward(inputs)

Expand Down Expand Up @@ -274,6 +274,7 @@ def test_step(
on_epoch=True,
add_dataloader_idx=False,
)
return logits

def test_epoch_end(
self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
Expand All @@ -295,6 +296,28 @@ def test_epoch_end(
"Calibration Plot", self.cal_plot.plot()[0]
)

if self.ood_detection:
id_logits = torch.cat(outputs[0], 0).float().cpu()
ood_logits = torch.cat(outputs[1], 0).float().cpu()

id_probs = F.softmax(id_logits, dim=-1)
ood_probs = F.softmax(ood_logits, dim=-1)

logits_fig = plot_hist(
[id_logits.max(-1).values, ood_logits.max(-1).values],
20,
"Histogram of the logits",
)[0]
probs_fig = plot_hist(
[id_probs.max(-1).values, ood_probs.max(-1).values],
20,
"Histogram of the likelihoods",
)[0]
self.logger.experiment.add_figure("Logit Histogram", logits_fig)
self.logger.experiment.add_figure(
"Likelihood Histogram", probs_fig
)

@staticmethod
def add_model_specific_args(
parent_parser: ArgumentParser,
Expand Down Expand Up @@ -404,7 +427,6 @@ def __init__(
)

def on_train_start(self) -> None:
# hyperparameters for performances
param = {}
param["storage"] = f"{get_model_size_mb(self)} MB"
if self.logger is not None:
Expand Down Expand Up @@ -432,7 +454,7 @@ def on_train_start(self) -> None:
)

def training_step(
self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
self, batch: Tuple[Tensor, Tensor], batch_idx: int
) -> STEP_OUTPUT:
batch = self.mixup(*batch)

Expand All @@ -450,7 +472,7 @@ def training_step(
return loss

def validation_step( # type: ignore
self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
self, batch: Tuple[Tensor, Tensor], batch_idx: int
) -> None:
inputs, targets = batch
logits = self.forward(inputs)
Expand All @@ -465,10 +487,10 @@ def validation_step( # type: ignore

def test_step(
self,
batch: Tuple[torch.Tensor, torch.Tensor],
batch: Tuple[Tensor, Tensor],
batch_idx: int,
dataloader_idx: Optional[int] = 0,
) -> None:
) -> Tensor:
inputs, targets = batch
logits = self.forward(inputs)
logits = rearrange(logits, "(n b) c -> b n c", n=self.num_estimators)
Expand Down Expand Up @@ -525,11 +547,16 @@ def test_step(
on_epoch=True,
add_dataloader_idx=False,
)
return logits

def test_epoch_end(
self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]
) -> None:
super().test_epoch_end(outputs)
self.log_dict(
self.test_cls_metrics.compute(),
)
self.test_cls_metrics.reset()

self.log_dict(
self.test_id_ens_metrics.compute(),
)
Expand All @@ -541,6 +568,42 @@ def test_epoch_end(
)
self.test_ood_ens_metrics.reset()

if isinstance(self.logger, TensorBoardLogger):
self.cal_plot.compute()
self.logger.experiment.add_figure(
"Calibration Plot", self.cal_plot.plot()[0]
)

if self.ood_detection:
id_logits = torch.cat(outputs[0], 0).float().cpu()
ood_logits = torch.cat(outputs[1], 0).float().cpu()

print(id_logits.shape)

id_probs = F.softmax(id_logits, dim=-1)
ood_probs = F.softmax(ood_logits, dim=-1)

logits_fig = plot_hist(
[
id_logits.mean(1).max(-1).values,
ood_logits.mean(1).max(-1).values,
],
20,
"Histogram of the logits",
)[0]
probs_fig = plot_hist(
[
id_probs.mean(1).max(-1).values,
ood_probs.mean(1).max(-1).values,
],
20,
"Histogram of the likelihoods",
)[0]
self.logger.experiment.add_figure("Logit Histogram", logits_fig)
self.logger.experiment.add_figure(
"Likelihood Histogram", probs_fig
)

@staticmethod
def add_model_specific_args(
parent_parser: ArgumentParser,
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/visualization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# fmt: off
from typing import Any, Tuple
from typing import List, Tuple

import matplotlib.pyplot as plt
import seaborn as sns
Expand Down Expand Up @@ -141,7 +141,7 @@ def __call__(


def plot_hist(
conf: Any,
conf: List[torch.Tensor],
bins: int = 20,
title: str = "Histogram with 'auto' bins",
dpi: int = 60,
Expand Down
Loading