Skip to content

Commit

Permalink
Remove internal metrics in favor of torchmetrics (#4287)
Browse files Browse the repository at this point in the history
* deprecate in favor of TM

* prune

* require

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* imports

* avg None

* reduction none

* prune IoU

* hard removal

* update metric computation

* Apply suggestions from code review

Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* typo

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
4 people authored Mar 21, 2022
1 parent a581be6 commit b0d9e75
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 297 deletions.
37 changes: 23 additions & 14 deletions examples/dgcnn_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torchmetrics.functional import jaccard_index

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, DynamicEdgeConv
from torch_geometric.utils import intersection_and_union as i_and_u

category = 'Airplane' # Pass in `None` to train on all categories.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
Expand Down Expand Up @@ -80,24 +81,32 @@ def train():
def test(loader):
model.eval()

y_mask = loader.dataset.y_mask
ious = [[] for _ in range(len(loader.dataset.categories))]

ious, categories = [], []
y_map = torch.empty(loader.dataset.num_classes, device=device).long()
for data in loader:
data = data.to(device)
pred = model(data).argmax(dim=1)
outs = model(data)

sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
data.category.tolist()):
category = list(ShapeNet.seg_classes.keys())[category]
part = ShapeNet.seg_classes[category]
part = torch.tensor(part, device=device)

y_map[part] = torch.arange(part.size(0), device=device)

iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],
num_classes=part.size(0), absent_score=1.0)
ious.append(iou)

i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch)
iou = i.cpu().to(torch.float) / u.cpu().to(torch.float)
iou[torch.isnan(iou)] = 1
categories.append(data.category)

# Find and filter the relevant classes for each category.
for iou, category in zip(iou.unbind(), data.category.unbind()):
ious[category.item()].append(iou[y_mask[category]])
iou = torch.tensor(ious, device=device)
category = torch.cat(categories, dim=0)

# Compute mean IoU.
ious = [torch.stack(iou).mean(0).mean(0) for iou in ious]
return torch.tensor(ious).mean().item()
mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU.
return float(mean_iou.mean()) # Global IoU.


for epoch in range(1, 31):
Expand Down
37 changes: 23 additions & 14 deletions examples/point_transformer_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from torch.nn import ReLU
from torch.nn import Sequential as Seq
from torch_cluster import knn_graph
from torch_scatter import scatter
from torchmetrics.functional import jaccard_index

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn.unpool import knn_interpolate
from torch_geometric.utils import intersection_and_union as i_and_u

category = 'Airplane' # Pass in `None` to train on all categories.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
Expand Down Expand Up @@ -197,24 +198,32 @@ def train():
def test(loader):
model.eval()

y_mask = loader.dataset.y_mask
ious = [[] for _ in range(len(loader.dataset.categories))]

ious, categories = [], []
y_map = torch.empty(loader.dataset.num_classes, device=device).long()
for data in loader:
data = data.to(device)
pred = model(data.x, data.pos, data.batch).argmax(dim=1)
outs = model(data.x, data.pos, data.batch)

sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
data.category.tolist()):
category = list(ShapeNet.seg_classes.keys())[category]
part = ShapeNet.seg_classes[category]
part = torch.tensor(part, device=device)

y_map[part] = torch.arange(part.size(0), device=device)

iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],
num_classes=part.size(0), absent_score=1.0)
ious.append(iou)

i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch)
iou = i.cpu().to(torch.float) / u.cpu().to(torch.float)
iou[torch.isnan(iou)] = 1
categories.append(data.category)

# Find and filter the relevant classes for each category.
for iou, category in zip(iou.unbind(), data.category.unbind()):
ious[category.item()].append(iou[y_mask[category]])
iou = torch.tensor(ious, device=device)
category = torch.cat(categories, dim=0)

# Compute mean IoU.
ious = [torch.stack(iou).mean(0).mean(0) for iou in ious]
return torch.tensor(ious).mean().item()
mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU.
return float(mean_iou.mean()) # Global IoU.


for epoch in range(1, 100):
Expand Down
37 changes: 23 additions & 14 deletions examples/pointnet2_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import torch
import torch.nn.functional as F
from pointnet2_classification import GlobalSAModule, SAModule
from torch_scatter import scatter
from torchmetrics.functional import jaccard_index

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, knn_interpolate
from torch_geometric.utils import intersection_and_union as i_and_u

category = 'Airplane' # Pass in `None` to train on all categories.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
Expand Down Expand Up @@ -106,24 +107,32 @@ def train():
def test(loader):
model.eval()

y_mask = loader.dataset.y_mask
ious = [[] for _ in range(len(loader.dataset.categories))]

ious, categories = [], []
y_map = torch.empty(loader.dataset.num_classes, device=device).long()
for data in loader:
data = data.to(device)
pred = model(data).argmax(dim=1)
outs = model(data)

sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
for out, y, category in zip(outs.split(sizes), data.y.split(sizes),
data.category.tolist()):
category = list(ShapeNet.seg_classes.keys())[category]
part = ShapeNet.seg_classes[category]
part = torch.tensor(part, device=device)

y_map[part] = torch.arange(part.size(0), device=device)

iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],
num_classes=part.size(0), absent_score=1.0)
ious.append(iou)

i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch)
iou = i.cpu().to(torch.float) / u.cpu().to(torch.float)
iou[torch.isnan(iou)] = 1
categories.append(data.category)

# Find and filter the relevant classes for each category.
for iou, category in zip(iou.unbind(), data.category.unbind()):
ious[category.item()].append(iou[y_mask[category]])
iou = torch.tensor(ious, device=device)
category = torch.cat(categories, dim=0)

# Compute mean IoU.
ious = [torch.stack(iou).mean(0).mean(0) for iou in ious]
return torch.tensor(ious).mean().item()
mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU.
return float(mean_iou.mean()) # Global IoU.


for epoch in range(1, 31):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
'matplotlib',
'scikit-image',
'pytorch-memlab',
'torchmetrics>=0.7',
'class-resolver>=0.3.2',
]

Expand Down
41 changes: 0 additions & 41 deletions test/utils/test_metric.py

This file was deleted.

13 changes: 0 additions & 13 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
structured_negative_sampling,
structured_negative_sampling_feasible)
from .train_test_split_edges import train_test_split_edges
from .metric import (accuracy, true_positive, true_negative, false_positive,
false_negative, precision, recall, f1_score,
intersection_and_union, mean_iou)

__all__ = [
'degree',
Expand Down Expand Up @@ -82,16 +79,6 @@
'structured_negative_sampling',
'structured_negative_sampling_feasible',
'train_test_split_edges',
'accuracy',
'true_positive',
'true_negative',
'false_positive',
'false_negative',
'precision',
'recall',
'f1_score',
'intersection_and_union',
'mean_iou',
]

classes = __all__
Loading

0 comments on commit b0d9e75

Please sign in to comment.