Skip to content

Commit

Permalink
fix missing updates
Browse files Browse the repository at this point in the history
  • Loading branch information
drprojects committed Jun 30, 2024
1 parent 7af051f commit 9b6ac18
Show file tree
Hide file tree
Showing 21 changed files with 487 additions and 76 deletions.
9 changes: 5 additions & 4 deletions src/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,10 +812,11 @@ def semantic_segmentation_oracle(

# Performance evaluation
from src.metrics import ConfusionMatrix
metric = ConfusionMatrix(num_classes, *metric_args, **metric_kwargs)
metric(pred.cpu(), target.cpu())
cm = ConfusionMatrix(num_classes, *metric_args, **metric_kwargs)
cm(pred.cpu(), target.cpu())
metrics = cm.all_metrics()

return metric.miou(), metric.iou(), metric.oa(), metric.macc()
return metrics

def instance_segmentation_oracle(self, *metric_args, **metric_kwargs):
"""Compute the oracle performance for instance segmentation.
Expand Down Expand Up @@ -929,7 +930,7 @@ def from_data_list(cls, data_list, follow_batch=None, exclude_keys=None):
# and 'obj' to a proper InstanceBatch.
# Note we will need to do the same in `get_example` to avoid
# breaking PyG Batch mechanisms
if batch.is_super:
if batch.is_super and isinstance(batch.sub, Cluster):
batch.sub = ClusterBatch.from_list(batch.sub)
if batch.obj is not None:
batch.obj = InstanceBatch.from_list(batch.obj)
Expand Down
7 changes: 4 additions & 3 deletions src/data/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,10 +752,11 @@ def semantic_segmentation_oracle(

# Performance evaluation
from src.metrics import ConfusionMatrix
metric = ConfusionMatrix(num_classes, *metric_args, **metric_kwargs)
metric(pred.cpu(), target.cpu())
cm = ConfusionMatrix(num_classes, *metric_args, **metric_kwargs)
cm(pred.cpu(), target.cpu())
metrics = cm.all_metrics()

return metric.miou(), metric.iou(), metric.oa(), metric.macc()
return metrics

def oracle(self, num_classes):
"""Compute the oracle predictions for instance and panoptic
Expand Down
10 changes: 10 additions & 0 deletions src/data/nag.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def num_points(self):
"""Number of points/nodes in the lower-level graph."""
return [d.num_points for d in self] if self.num_levels > 0 else 0

@property
def level_ratios(self):
    """Ratios of number of nodes between consecutive partition
    levels. This can be useful for investigating how much each
    partition level 'compresses' the previous one.
    """
    ratios = {}
    for level in range(self.num_levels - 1):
        key = f"|P_{level}| / |P_{level + 1}|"
        ratios[key] = self.num_points[level] / self.num_points[level + 1]
    return ratios

def to_list(self):
"""Return the Data list"""
return self._list
Expand Down
17 changes: 3 additions & 14 deletions src/datamodules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,10 @@ def setup(self, stage=None):
def set_transforms(self):
"""Parse in self.hparams in search for '*transform*' keys and
instantiate the corresponding transforms.
Credit: https://github.com/torch-points3d/torch-points3d
"""
for key_name in self.hparams.keys():
if "transform" in key_name:
name = key_name.replace("transforms", "transform")
params = getattr(self.hparams, key_name, None)
if params is None:
continue
try:
transform = instantiate_transforms(params)
except Exception:
log.exception(f"Error trying to create {name}, {params}")
continue
setattr(self, name, transform)
t_dict = instantiate_datamodule_transforms(self.hparams, log=log)
for key, transform in t_dict.items():
setattr(self, key, transform)

def check_tta_conflicts(self):
"""Make sure the transforms are Test-Time Augmentation-friendly
Expand Down
5 changes: 3 additions & 2 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# - automatically loads environment variables from ".env" file if exists
#
# how it works:
# - the line above recursively searches for either ".git" or "pyproject.toml" in present
# - the line above recursively searches for either ".git" or "README.md" in present
# and parent dirs, to determine the project root dir
# - adds root dir to the PYTHONPATH (if `pythonpath=True`), so this file can be run from
# any place without installing project as a package
Expand Down Expand Up @@ -49,7 +49,8 @@
# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
OmegaConf.register_new_resolver("eval", eval)
if not OmegaConf.has_resolver('eval'):
OmegaConf.register_new_resolver('eval', eval)

log = utils.get_pylogger(__name__)

Expand Down
28 changes: 28 additions & 0 deletions src/metrics/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from torchmetrics.classification import MulticlassConfusionMatrix
from torch_scatter import scatter_add
from src.metrics.mean_average_precision import BaseMetricResults


log = logging.getLogger(__name__)
Expand All @@ -12,6 +13,17 @@
__all__ = ['ConfusionMatrix']


class SemanticMetricResults(BaseMetricResults):
    """Wrapper holding the final semantic segmentation metrics:
    overall accuracy, mean accuracy, mean IoU, per-class IoU, and
    the mask of classes actually present in the targets.
    """
    # Restrict instances to exactly these metric fields
    __slots__ = ('oa', 'macc', 'miou', 'iou_per_class', 'seen_class')


class ConfusionMatrix(MulticlassConfusionMatrix):
"""TorchMetrics's MulticlassConfusionMatrix but tailored to our
needs. In particular, new methods allow computing OA, mAcc, mIoU
Expand Down Expand Up @@ -231,6 +243,22 @@ def print_metrics(self, class_names=None):
continue
print(f' {c:<13}: {iou:0.2f}')

def all_metrics(self, as_percent=True):
    """Helper returning all important metrics gathered in a single
    `SemanticMetricResults` object: the OA, mAcc, mIoU, per-class
    IoU, and the mask of classes seen in the targets.

    :param as_percent: bool
        If True, the returned metrics are expressed in [0, 100]
    """
    metrics = SemanticMetricResults()
    metrics.oa = self.oa(as_percent=as_percent)
    metrics.macc = self.macc(as_percent=as_percent)
    metrics.miou = self.miou(as_percent=as_percent)
    # self.iou() returns both the per-class IoU and a mask flagging
    # which classes actually appear in the ground truth
    iou, seen = self.iou(as_percent=as_percent)
    metrics.iou_per_class = iou
    metrics.seen_class = seen
    return metrics


def save_confusion_matrix(cm, path2save, ordered_names):
import seaborn as sns
Expand Down
Loading

0 comments on commit 9b6ac18

Please sign in to comment.