Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add visualization tools #50

Merged
merged 11 commits into from
Oct 15, 2023
Prev Previous commit
Next Next commit
✨ Add % of data in cal. bins & improve plot
  • Loading branch information
o-laurent committed Oct 14, 2023
commit ce775b6e1aa7804fd490910f295b20c64f123e19
2 changes: 0 additions & 2 deletions auto_tutorials_source/tutorial_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@

# Compute and plot the calibration figure
cal_plot.compute()
cal_plot.plot()

# %%
# 5. Fitting the Scaler to Improve the Calibration
Expand Down Expand Up @@ -148,7 +147,6 @@
print(f"ECE after scaling - {cal*100:.3}%.")

cal_plot.compute()
cal_plot.plot()

# %%
# The top-label calibration should be improved.
Expand Down
6 changes: 2 additions & 4 deletions torch_uncertainty/routines/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,8 @@ def test_epoch_end(
self.test_ood_metrics.reset()

if isinstance(self.logger, TensorBoardLogger):
self.cal_plot.compute()
self.logger.experiment.add_figure(
"Calibration Plot", self.cal_plot.plot()[0]
"Calibration Plot", self.cal_plot.compute()[0]
)

if self.ood_detection:
Expand Down Expand Up @@ -569,9 +568,8 @@ def test_epoch_end(
self.test_ood_ens_metrics.reset()

if isinstance(self.logger, TensorBoardLogger):
self.cal_plot.compute()
self.logger.experiment.add_figure(
"Calibration Plot", self.cal_plot.plot()[0]
"Calibration Plot", self.cal_plot.compute()[0]
)

if self.ood_detection:
Expand Down
51 changes: 29 additions & 22 deletions torch_uncertainty/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,12 @@ def update(
else:
self.acc.append((preds.argmax(-1) == targets).cpu())

def compute(self) -> None:
"""Compute the calibration plot."""
def compute(self) -> Tuple[Figure, Axes]:
"""Compute and plot the calibration figure.

Returns:
Tuple[Figure, Axes]: The figure and axes of the plot.
"""
confidence = torch.cat(self.conf)
acc = torch.cat(self.acc)

Expand All @@ -88,41 +92,45 @@ def compute(self) -> None:
val, inverse, counts = bin_ids.unique(
return_inverse=True, return_counts=True
)
val = torch.nn.functional.one_hot(val.long(), num_classes=10)
val_oh = torch.nn.functional.one_hot(val.long(), num_classes=10)

# add 1e-6 to avoid division NaNs
self.values = (
val.T.float()
values = (
val_oh.T.float()
@ torch.sum(
acc.unsqueeze(1) * torch.nn.functional.one_hot(inverse).float(),
0,
)
/ (val.T @ counts + 1e-6).float()
/ (val_oh.T @ counts + 1e-6).float()
)
counts_all = (val_oh.T @ counts).float()
total = torch.sum(counts)

def plot(self) -> Tuple[Figure, Axes]:
"""Plot the calibration.

Returns:
Tuple[Figure, Axes]: The figure and axes of the plot.
"""
plt.rc("axes", axisbelow=True)
fig, ax = plt.subplots(1, figsize=self.figsize)
ax.hist(
x=[self.bin_width * i for i in range(self.num_bins)],
weights=self.values,
bins=[self.bin_width * i for i in range(self.num_bins + 1)],
x=[self.bin_width * i * 100 for i in range(self.num_bins)],
weights=values * 100,
bins=[self.bin_width * i * 100 for i in range(self.num_bins + 1)],
alpha=0.7,
linewidth=1,
edgecolor="#0d559f",
color="#1f77b4",
)
ax.plot([0, 1], [0, 1], "--", color="black")
for i, count in enumerate(counts_all):
ax.text(
3.0 + 9.9 * i,
1,
f"{int(count/total*100)}%",
fontsize=8,
)

ax.plot([0, 100], [0, 100], "--", color="#0d559f")
plt.grid(True, linestyle="--", alpha=0.7, zorder=0)
ax.set_xlabel("Top-class Confidence", fontsize=16)
ax.set_ylabel("Success Rate", fontsize=16)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_xlabel("Top-class Confidence (%)", fontsize=16)
ax.set_ylabel("Success Rate (%)", fontsize=16)
ax.set_xlim(0, 100)
ax.set_ylim(0, 100)
ax.set_aspect("equal", "box")
fig.tight_layout()
return fig, ax
Expand All @@ -140,8 +148,7 @@ def __call__(
Tuple[Figure, Axes]: The figure and axes of the plot.
"""
self.update(preds, targets)
self.compute()
return self.plot()
return self.compute()


def plot_hist(
Expand Down