Skip to content

Commit

Permalink
👕 Start improving eval loop code & simplify segformer tests & activat…
Browse files Browse the repository at this point in the history
…e codecov test analysis
  • Loading branch information
o-laurent committed Nov 19, 2024
1 parent 463c053 commit 8650ef3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 69 deletions.
4 changes: 0 additions & 4 deletions tests/models/test_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,4 @@ class TestSegformer:
def test_main(self):
model = seg_former(10, 0)
seg_former(10, 1)
seg_former(10, 2)
seg_former(10, 3)
seg_former(10, 4)
seg_former(10, 5)
model(torch.randn(1, 3, 32, 32))
2 changes: 1 addition & 1 deletion torch_uncertainty/models/segmentation/segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _get_embed_dims(arch: int) -> list[int]:
return [64, 128, 320, 512]


def _get_depths(arch: int) -> list[int]:
def _get_depths(arch: int) -> list[int]: # coverage: ignore
if arch == 0 or arch == 1:
return [2, 2, 2, 2]
if arch == 2:
Expand Down
104 changes: 40 additions & 64 deletions torch_uncertainty/utils/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@
from rich import get_console
from rich.console import Group
from rich.table import Table
from torch import Tensor

PERCENTAGE_METRICS = [
"Acc",
"AUPR",
"AUROC",
"FPR95",
"Cov@5Risk",
"Risk@80Cov",
"pixAcc",
"mIoU",
"AURC",
"AUGRC",
"mAcc",
]


def _add_row(table: Table, metric_name: str, value: Tensor) -> None:
if metric_name in PERCENTAGE_METRICS:
value = value * 100
table.add_row(metric_name, f"{value.item():.2f}%")
else:
table.add_row(metric_name, f"{value.item():.5f}")


class TUEvaluationLoop(_EvaluationLoop):
Expand All @@ -20,21 +43,6 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
# test/post: Post-Processing Metrics
# test/seg: Segmentation Metrics

# In percentage
percentage_metrics = [
"Acc",
"AUPR",
"AUROC",
"FPR95",
"Cov@5Risk",
"Risk@80Cov",
"pixAcc",
"mIoU",
"AURC",
"AUGRC",
"mAcc",
]

metrics = {}
for result in results:
for key, value in result.items():
Expand Down Expand Up @@ -88,64 +96,44 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
table.add_column(first_col_name, justify="center", style="cyan", width=12)
table.add_column("Classification", justify="center", style="magenta", width=25)
cls_metrics = OrderedDict(sorted(metrics["cls"].items()))
for metric, value in cls_metrics.items():
if metric in percentage_metrics:
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
for metric_name, value in cls_metrics.items():
_add_row(table, metric_name, value)
tables.append(table)

if "seg" in metrics:
table = Table()
table.add_column(first_col_name, justify="center", style="cyan", width=12)
table.add_column("Segmentation", justify="center", style="magenta", width=25)
seg_metrics = OrderedDict(sorted(metrics["seg"].items()))
for metric, value in seg_metrics.items():
if metric in percentage_metrics:
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
for metric_name, value in seg_metrics.items():
_add_row(table, metric_name, value)
tables.append(table)

if "reg" in metrics:
table = Table()
table.add_column(first_col_name, justify="center", style="cyan", width=12)
table.add_column("Regression", justify="center", style="magenta", width=25)
reg_metrics = OrderedDict(sorted(metrics["reg"].items()))
for metric, value in reg_metrics.items():
if metric in percentage_metrics: # coverage: ignore
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
for metric_name, value in reg_metrics.items():
_add_row(table, metric_name, value)
tables.append(table)

if "cal" in metrics:
table = Table()
table.add_column(first_col_name, justify="center", style="cyan", width=12)
table.add_column("Calibration", justify="center", style="magenta", width=25)
cal_metrics = OrderedDict(sorted(metrics["cal"].items()))
for metric, value in cal_metrics.items():
if metric in percentage_metrics:
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
for metric_name, value in cal_metrics.items():
_add_row(table, metric_name, value)
tables.append(table)

if "ood" in metrics:
table = Table()
table.add_column(first_col_name, justify="center", style="cyan", width=12)
table.add_column("OOD Detection", justify="center", style="magenta", width=25)
ood_metrics = OrderedDict(sorted(metrics["ood"].items()))
for metric, value in ood_metrics.items():
if metric in percentage_metrics:
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
for metric_name, value in ood_metrics.items():
_add_row(table, metric_name, value)
tables.append(table)

if "sc" in metrics:
Expand All @@ -158,25 +146,17 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
width=25,
)
sc_metrics = OrderedDict(sorted(metrics["sc"].items()))
for metric, value in sc_metrics.items():
if metric in percentage_metrics:
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
for metric_name, value in sc_metrics.items():
_add_row(table, metric_name, value)
tables.append(table)

if "post" in metrics:
table = Table()
table.add_column(first_col_name, justify="center", style="cyan", width=12)
table.add_column("Post-Processing", justify="center", style="magenta", width=25)
post_metrics = OrderedDict(sorted(metrics["post"].items()))
for metric, value in post_metrics.items():
if metric in percentage_metrics:
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
for metric_name, value in post_metrics.items():
_add_row(table, metric_name, value)
tables.append(table)

if "shift" in metrics:
Expand All @@ -190,14 +170,10 @@ def _print_results(results: list[_OUT_DICT], stage: str) -> None:
width=25,
)
shift_metrics = OrderedDict(sorted(metrics["shift"].items()))
for metric, value in shift_metrics.items():
if metric == "shift_severity":
for metric_name, value in shift_metrics.items():
if metric_name == "shift_severity":
continue
if metric in percentage_metrics:
value = value * 100
table.add_row(metric, f"{value.item():.2f}%")
else:
table.add_row(metric, f"{value.item():.5f}")
_add_row(table, metric_name, value)
tables.append(table)

console = get_console()
Expand Down

0 comments on commit 8650ef3

Please sign in to comment.