Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torchmetrics #511

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies:
- cuda-version # works also with CPU-only system.
- pytorch >=1.12
- lightning >=2.0
- torchmetrics >=0.7.0,<0.11
- torchmetrics
- ogb
- pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric`
- wandb
Expand Down
907 changes: 0 additions & 907 deletions graphium/ipu/ipu_metrics.py

This file was deleted.

105 changes: 92 additions & 13 deletions graphium/trainer/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import torch
from torch import Tensor
import operator as op
from copy import deepcopy

from torch.nn.modules.loss import _Loss
from torchmetrics.utilities.distributed import reduce
import torchmetrics.functional.regression.mae

Expand Down Expand Up @@ -137,7 +139,7 @@ class MetricWrapper:

def __init__(
self,
metric: Union[str, Callable],
metric: Union[str, torchmetrics.Metric, torch.nn.modules.loss._Loss],
threshold_kwargs: Optional[Dict[str, Any]] = None,
target_nan_mask: Optional[Union[str, int]] = None,
multitask_handling: Optional[str] = None,
Expand Down Expand Up @@ -187,7 +189,7 @@ def __init__(
Other arguments to call with the metric
"""

self.metric, self.metric_name = self._get_metric(metric)
metric_class, self.metric_name = self._get_metric_class(metric)
self.thresholder = None
if threshold_kwargs is not None:
self.thresholder = Thresholder(**threshold_kwargs)
Expand All @@ -198,6 +200,26 @@ def __init__(
self.target_to_int = target_to_int
self.kwargs = kwargs

self.metric, self.kwargs = self._initialize_metric(metric_class, self.kwargs)

@staticmethod
def _initialize_metric(metric, kwargs):
    r"""
    Instantiate the metric, or validate an already-instantiated one.

    Parameters:
        metric:
            Either a metric class (anything for which ``isinstance(metric, type)``
            is true), which is instantiated as ``metric(**kwargs)``, or an
            already-built ``torchmetrics.Metric`` instance, which is returned as-is.
        kwargs:
            Keyword arguments used to instantiate the metric when a class is
            provided. Returned unchanged in both cases.

    Returns:
        A tuple ``(metric_instance, kwargs)``.

    Raises:
        ValueError: if ``metric`` is an instance that is not a
            ``torchmetrics.Metric``.
    """
    # Instance provided: it must already expose the torchmetrics interface.
    # NOTE(review): the wrapper's signature also advertises torch loss modules
    # (`_Loss`); instances of those are rejected here — presumably they should
    # be wrapped in `MetricToTorchMetrics` first. Confirm with callers.
    if not isinstance(metric, type):
        if not isinstance(metric, torchmetrics.Metric):
            # Fixed: the original concatenated f-strings had no separator,
            # producing a run-on message like "...<class 'X'>Use `METRICS_DICT`...".
            raise ValueError(
                f"metric must be a torchmetrics.Metric, provided: {type(metric)}. "
                f"Use `METRICS_DICT` to get the metric class"
            )
        return metric, kwargs

    # Class provided: instantiate it with the supplied kwargs.
    return metric(**kwargs), kwargs


@staticmethod
def _parse_target_nan_mask(target_nan_mask):
"""
Expand Down Expand Up @@ -254,7 +276,7 @@ def _parse_multitask_handling(multitask_handling, target_nan_mask):
return multitask_handling

@staticmethod
def _get_metric(metric):
def _get_metric_class(metric):
from graphium.utils.spaces import METRICS_DICT

if isinstance(metric, str):
Expand All @@ -265,9 +287,10 @@ def _get_metric(metric):
metric = metric
return metric, metric_name

def compute(self, preds: Tensor, target: Tensor) -> Tensor:
def update(self, preds: Tensor, target: Tensor) -> Tensor:
r"""
Compute the metric, apply the thresholder if provided, and manage the NaNs
Update the parameters of the metric, apply the thresholder if provided, and manage the NaNs.
See `torchmetrics.Metric.update` for more details.
"""
if preds.ndim == 1:
preds = preds.unsqueeze(-1)
Expand Down Expand Up @@ -300,7 +323,8 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor:
target = target.squeeze()
if self.target_to_int:
target = target.to(int)
metric_val = self.metric(preds, target, **self.kwargs)
self.metric.update(preds, target)

elif self.multitask_handling == "flatten":
# Flatten the tensors, apply the nan filtering, then compute the metrics
if classifigression:
Expand All @@ -313,7 +337,8 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor:
target = target.squeeze()
if self.target_to_int:
target = target.to(int)
metric_val = self.metric(preds, target, **self.kwargs)
self.metric.update(preds, target)

elif self.multitask_handling == "mean-per-label":
# Loop the columns (last dim) of the tensors, apply the nan filtering, compute the metrics per column, then average the metrics
target_list = [target[..., ii][~target_nans[..., ii]] for ii in range(target.shape[-1])]
Expand All @@ -322,24 +347,52 @@ def compute(self, preds: Tensor, target: Tensor) -> Tensor:
preds_list = [preds[..., i, :][~target_nans[..., i]] for i in range(preds.shape[1])]
else:
preds_list = [preds[..., ii][~target_nans[..., ii]] for ii in range(preds.shape[-1])]
metric_val = []

if not isinstance(self.metric, list):
self.metric = [deepcopy(self.metric) for _ in range(len(target_list))]
for ii in range(len(target_list)):
try:
this_preds, this_target = self._filter_nans(preds_list[ii], target_list[ii])
if self.squeeze_targets:
this_target = this_target.squeeze()
if self.target_to_int:
this_target = this_target.to(int)
metric_val.append(self.metric(this_preds, this_target, **self.kwargs))
self.metric[ii].update(this_preds, this_target)
except:
pass
# Average the metric
metric_val = nan_mean(torch.stack(metric_val))
else:
# Wrong option
raise ValueError(f"Invalid option `self.multitask_handling={self.multitask_handling}`")

return metric_val
def compute(self) -> Tensor:
    r"""
    Aggregate the accumulated state into the final metric value.

    With ``multitask_handling == "mean-per-label"``, ``self.metric`` is a list
    of per-label metrics; each is computed and the NaN-ignoring mean of the
    stacked results is returned. Otherwise the single wrapped metric's
    ``compute()`` result is returned directly.
    """
    if self.multitask_handling != "mean-per-label":
        return self.metric.compute()
    per_label = torch.stack([sub_metric.compute() for sub_metric in self.metric])
    return nan_mean(per_label)

def update_compute(self, preds: Tensor, target: Tensor) -> Tensor:
    r"""
    Accumulate one batch of predictions/targets, then immediately return the
    current metric value.

    Equivalent to calling ``self.update(preds, target)`` followed by
    ``self.compute()``; thresholding and NaN handling happen inside ``update``.
    """
    self.update(preds, target)
    return self.compute()

def reset(self):
    r"""
    Clear the accumulated state of the wrapped metric(s) by delegating to
    each metric's own ``reset`` method.

    With ``multitask_handling == "mean-per-label"``, ``self.metric`` holds one
    metric per label and every one of them is reset.
    """
    if self.multitask_handling != "mean-per-label":
        self.metric.reset()
        return
    for sub_metric in self.metric:
        sub_metric.reset()


def _filter_nans(self, preds: Tensor, target: Tensor):
"""Handle the NaNs according to the chosen options"""
Expand Down Expand Up @@ -405,10 +458,36 @@ def __getstate__(self):

def __setstate__(self, state: dict):
    """Reload the class from pickling.

    Rebuilds the transient members that ``__getstate__`` serialized in reduced
    form: resolves the metric back to its class and instance, and
    re-instantiates the thresholder from its kwargs.
    """
    # Resolve a metric name/string back to (metric class, metric name).
    state["metric"], state["metric_name"] = self._get_metric_class(state["metric"])
    thresholder = state.pop("threshold_kwargs", None)
    if thresholder is not None:
        thresholder = Thresholder(**thresholder)
    state["thresholder"] = thresholder
    # BUG FIX: `__init__` stores the initialized metric's kwargs in
    # `self.kwargs`; the previous code wrote them to a stray
    # `at_compute_kwargs` attribute instead, leaving `kwargs` stale after
    # unpickling.
    state["metric"], state["kwargs"] = self._initialize_metric(state["metric"], state["kwargs"])

    self.__dict__.update(state)

class MetricToTorchMetrics:
    r"""
    Adapter that gives a plain metric/loss callable the ``update`` /
    ``compute`` / ``reset`` interface of ``torchmetrics.Metric``, so it can be
    used through ``MetricWrapper``. Aggregation is limited to averaging the
    per-update scores.
    """

    def __init__(self, metric):
        # The wrapped callable, invoked as `metric(preds, target)` per update.
        self.metric = metric
        # One score per `update` call; averaged (NaN-ignoring) in `compute`.
        self.scores = []

    def update(self, preds: Tensor, target: Tensor):
        """Evaluate the wrapped callable on one batch and record its score."""
        score = self.metric(preds, target)
        self.scores.append(score)

    def compute(self):
        """Return the single recorded score, or the NaN-ignoring mean of all."""
        n_scores = len(self.scores)
        if n_scores == 0:
            raise ValueError("No scores to compute")
        if n_scores == 1:
            return self.scores[0]
        return nan_mean(torch.stack(self.scores))

    def reset(self):
        """Drop every accumulated score."""
        self.scores = []

Loading
Loading