Code cleaning after classification refactor 2/n #1252

Merged 77 commits on Nov 22, 2022

Commits
f0f8dce  functionals cleanup (SkafteNicki, Oct 5, 2022)
a2c42b4  remove old functions (SkafteNicki, Oct 6, 2022)
8f3328c  Merge branch 'master' into cleanup/2 (SkafteNicki, Oct 6, 2022)
a0a3d3a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 6, 2022)
b0d8e06  Merge branch 'master' into cleanup/2 (SkafteNicki, Oct 6, 2022)
80815f9  revert stat score due to dice (SkafteNicki, Oct 6, 2022)
54960dc  Merge branch 'master' into cleanup/2 (Borda, Oct 12, 2022)
d783889  Merge branch 'master' into cleanup/2 (Borda, Oct 22, 2022)
67b8dcc  Merge branch 'master' into cleanup/2 (SkafteNicki, Oct 25, 2022)
4e17049  Merge branch 'cleanup/2' of https://github.com/PyTorchLightning/metri… (SkafteNicki, Oct 25, 2022)
16b4295  clean docstring (SkafteNicki, Oct 25, 2022)
282751b  more docstring cleaning (SkafteNicki, Oct 25, 2022)
7e246a1  remaining changes to impl (SkafteNicki, Oct 31, 2022)
7b86eab  Merge branch 'master' into cleanup/2 (SkafteNicki, Oct 31, 2022)
619ddf6  fix imports (SkafteNicki, Oct 31, 2022)
3a7d89b  Merge branch 'cleanup/2' of https://github.com/PyTorchLightning/metri… (SkafteNicki, Oct 31, 2022)
6a1df14  remove old class interfaces (SkafteNicki, Nov 5, 2022)
6dded5b  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 5, 2022)
08958e1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 5, 2022)
6fc8573  remove old warning (SkafteNicki, Nov 5, 2022)
8b72a67  changelog (SkafteNicki, Nov 5, 2022)
397c0d2  change to object (SkafteNicki, Nov 5, 2022)
d11b8ee  remove old docstrings (SkafteNicki, Nov 5, 2022)
d9f14ff  merge (SkafteNicki, Nov 5, 2022)
220f0d3  fix a lot of issues (SkafteNicki, Nov 5, 2022)
66e9d52  fix docstrings (SkafteNicki, Nov 5, 2022)
16f8042  fix arg ordering (SkafteNicki, Nov 5, 2022)
5c93067  fix (SkafteNicki, Nov 5, 2022)
bd9478a  another fix (SkafteNicki, Nov 5, 2022)
19027ac  doc fix (SkafteNicki, Nov 5, 2022)
ae7379d  fix import (SkafteNicki, Nov 5, 2022)
07c40ef  try fixing docs (SkafteNicki, Nov 5, 2022)
cc79a76  fix integration testing (SkafteNicki, Nov 5, 2022)
ca1b607  fix top_k arg (SkafteNicki, Nov 5, 2022)
20589ce  missing assert (SkafteNicki, Nov 6, 2022)
86a5512  get dice working (SkafteNicki, Nov 8, 2022)
49a6be1  fix docs (SkafteNicki, Nov 8, 2022)
cf2bbd7  mistake fix (SkafteNicki, Nov 8, 2022)
8fccc4b  fix (SkafteNicki, Nov 8, 2022)
284129d  update doctests (SkafteNicki, Nov 8, 2022)
204bb1f  fix (SkafteNicki, Nov 8, 2022)
3455f36  please work (SkafteNicki, Nov 8, 2022)
4ec6a0a  fix doctests (SkafteNicki, Nov 8, 2022)
d8ca9a6  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 8, 2022)
194469c  fix (SkafteNicki, Nov 8, 2022)
82212b9  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 9, 2022)
0f17037  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 10, 2022)
7c2da30  fix broken tests (SkafteNicki, Nov 10, 2022)
33e5604  fix more unittests (SkafteNicki, Nov 10, 2022)
268de19  fix unittesting (SkafteNicki, Nov 11, 2022)
ebaa7cf  Merge branch 'cleanup/2' of https://github.com/PyTorchLightning/metri… (SkafteNicki, Nov 11, 2022)
ea8c997  fix unittests (SkafteNicki, Nov 11, 2022)
37a47c3  fix unittests (SkafteNicki, Nov 11, 2022)
4f8997c  fix readme (SkafteNicki, Nov 11, 2022)
8bfb106  Merge branch 'master' into cleanup/2 (Borda, Nov 11, 2022)
0b98a0c  fix more unittests (SkafteNicki, Nov 13, 2022)
6065fdb  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 14, 2022)
e3f19e0  another fix (SkafteNicki, Nov 14, 2022)
cbab203  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 16, 2022)
305b860  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 17, 2022)
1f422c0  rev doctests (Borda, Nov 17, 2022)
7cbbabf  rev doctests (Borda, Nov 17, 2022)
7206fd0  Merge branch 'master' into cleanup/2 (Borda, Nov 17, 2022)
228b6a5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 17, 2022)
69725b3  tasks (Borda, Nov 17, 2022)
146c014  Merge branch 'master' into cleanup/2 (mergify[bot], Nov 18, 2022)
4a4cb50  Merge branch 'master' into cleanup/2 (mergify[bot], Nov 18, 2022)
eaef7ec  Merge branch 'master' into cleanup/2 (mergify[bot], Nov 19, 2022)
b51a1d9  Merge branch 'master' into cleanup/2 (mergify[bot], Nov 21, 2022)
1f8054d  fix doctests (SkafteNicki, Nov 22, 2022)
8a03195  fix more doctests (SkafteNicki, Nov 22, 2022)
6c6e6d4  fix more doctests (SkafteNicki, Nov 22, 2022)
6a726fd  fix mypy (SkafteNicki, Nov 22, 2022)
2202d4d  Merge branch 'master' into cleanup/2 (SkafteNicki, Nov 22, 2022)
80c4e1e  Merge branch 'master' into cleanup/2 (mergify[bot], Nov 22, 2022)
694421a  Merge branch 'master' into cleanup/2 (mergify[bot], Nov 22, 2022)
2037602  Merge branch 'master' into cleanup/2 (mergify[bot], Nov 22, 2022)

3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Changed minimum Pytorch version to be 1.8 ([#1263](https://github.com/Lightning-AI/metrics/pull/1263))


- Changed interface for all functional and modular classification metrics after refactor ([#1252](https://github.com/Lightning-AI/metrics/pull/1252))


### Deprecated

-
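To make the changelog entry above concrete: a minimal before/after sketch of the classification interface change, not code from this PR, assuming the 5-class example used in the README changes below.

```python
import torch
import torchmetrics
from torchmetrics.classification import MulticlassAccuracy

preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

# Before the refactor, the task was inferred from the input shapes:
#   acc = torchmetrics.Accuracy()(preds, target)

# After the refactor, the task is explicit, either via the wrapper class ...
acc = torchmetrics.Accuracy(task="multiclass", num_classes=5)(preds, target)

# ... or via the task-specific class directly.
acc = MulticlassAccuracy(num_classes=5)(preds, target)
```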
8 changes: 5 additions & 3 deletions README.md
@@ -123,7 +123,7 @@ import torch
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()
metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

# move the metric to device you want computations to take place
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -169,7 +169,7 @@ def metric_ddp(rank, world_size):
dist.init_process_group("gloo", rank=rank, world_size=world_size)

# initialize model
metric = torchmetrics.Accuracy()
metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

# define a model and append your metric to it
# this allows metric states to be placed on correct accelerators when
@@ -263,7 +263,9 @@ import torchmetrics
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
acc = torchmetrics.functional.classification.multiclass_accuracy(
preds, target, num_classes=5
)
```

### Covered domains and example metrics
16 changes: 0 additions & 16 deletions docs/source/classification/precision_recall.rst

This file was deleted.

14 changes: 7 additions & 7 deletions docs/source/pages/lightning.rst
@@ -34,7 +34,7 @@ The example below shows how to use a metric in your `LightningModule <https://py

def __init__(self):
...
self.accuracy = torchmetrics.Accuracy()
self.accuracy = torchmetrics.Accuracy(task='multiclass')

def training_step(self, batch, batch_idx):
x, y = batch
@@ -80,8 +80,8 @@ value by calling ``.compute()``.

def __init__(self):
...
self.train_acc = torchmetrics.Accuracy()
self.valid_acc = torchmetrics.Accuracy()
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')

def training_step(self, batch, batch_idx):
x, y = batch
@@ -105,8 +105,8 @@ of the metrics.

def __init__(self):
...
self.train_acc = torchmetrics.Accuracy()
self.valid_acc = torchmetrics.Accuracy()
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')

def training_step(self, batch, batch_idx):
x, y = batch
@@ -143,7 +143,7 @@ mixed as it can lead to wrong results.

def __init__(self):
...
self.valid_acc = torchmetrics.Accuracy()
self.valid_acc = torchmetrics.Accuracy(task='multiclass')

def validation_step(self, batch, batch_idx):
logits = self(x)
@@ -187,7 +187,7 @@ The following contains a list of pitfalls to be aware of:

def __init__(self):
...
self.val_acc = nn.ModuleList([torchmetrics.Accuracy() for _ in range(2)])
self.val_acc = nn.ModuleList([torchmetrics.Accuracy(task='multiclass') for _ in range(2)])

def val_dataloader(self):
return [DataLoader(...), DataLoader(...)]
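The pitfall example above is cut off right after the dataloaders are defined. A sketch of how the per-dataloader metrics are typically consumed in the step methods, assuming two validation dataloaders and a 5-class problem; the backbone and logging names are hypothetical, only the ModuleList pattern comes from the hunk above.

```python
import torch
from torch import nn
from pytorch_lightning import LightningModule
from torchmetrics import Accuracy


class MyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 5)  # hypothetical model
        # one metric instance per validation dataloader, as in the hunk above
        self.val_acc = nn.ModuleList(
            [Accuracy(task="multiclass", num_classes=5) for _ in range(2)]
        )

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        preds = self.backbone(x)
        # update only the metric that belongs to the current dataloader
        self.val_acc[dataloader_idx].update(preds, y)

    def validation_epoch_end(self, outputs):
        for i, acc in enumerate(self.val_acc):
            self.log(f"val_acc/dataloader_{i}", acc.compute())
            acc.reset()
```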
43 changes: 24 additions & 19 deletions docs/source/pages/overview.rst
@@ -24,10 +24,10 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

.. code-block:: python

from torchmetrics.classification import Accuracy
from torchmetrics.classification import BinaryAccuracy

train_accuracy = Accuracy()
valid_accuracy = Accuracy()
train_accuracy = BinaryAccuracy()
valid_accuracy = BinaryAccuracy()

for epoch in range(epochs):
for x, y in train_data:
@@ -84,14 +84,14 @@ be moved to the same device as the input of the metric:

.. code-block:: python

from torchmetrics import Accuracy
from torchmetrics.classification import BinaryAccuracy

target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))

# Metric states are always initialized on cpu, and needs to be moved to
# the correct device
confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0))
confmat = BinaryAccuracy().to(torch.device("cuda", 0))
out = confmat(preds, target)
print(out.device) # cuda:0

@@ -107,16 +107,17 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.

.. testcode::

from torchmetrics import Accuracy, MetricCollection
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy

class MyModule(torch.nn.Module):
def __init__(self):
...
# valid ways metrics will be identified as child modules
self.metric1 = Accuracy()
self.metric2 = nn.ModuleList(Accuracy())
self.metric3 = nn.ModuleDict({'accuracy': Accuracy()})
self.metric4 = MetricCollection([Accuracy()]) # torchmetrics build-in collection class
self.metric1 = BinaryAccuracy()
self.metric2 = nn.ModuleList(BinaryAccuracy())
self.metric3 = nn.ModuleDict({'accuracy': BinaryAccuracy()})
self.metric4 = MetricCollection([BinaryAccuracy()]) # torchmetrics build-in collection class

def forward(self, batch):
data, target = batch
@@ -254,33 +255,37 @@ Example:

.. testcode::

from torchmetrics import MetricCollection, Accuracy, Precision, Recall
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metric_collection = MetricCollection([
Accuracy(),
Precision(num_classes=3, average='macro'),
Recall(num_classes=3, average='macro')
MulticlassAccuracy(num_classes=3, average="micro"),
MulticlassPrecision(num_classes=3, average="macro"),
MulticlassRecall(num_classes=3, average="macro")
])
print(metric_collection(preds, target))

.. testoutput::
:options: +NORMALIZE_WHITESPACE

{'Accuracy': tensor(0.1250),
'Precision': tensor(0.0667),
'Recall': tensor(0.1111)}
{'MulticlassAccuracy': tensor(0.1250),
'MulticlassPrecision': tensor(0.0667),
'MulticlassRecall': tensor(0.1111)}

Similarly it can also reduce the amount of code required to log multiple metrics
inside your LightningModule. In most cases we just have to replace ``self.log`` with ``self.log_dict``.

.. testcode::

from torchmetrics import Accuracy, MetricCollection, Precision, Recall
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

class MyModule(LightningModule):
def __init__(self):
metrics = MetricCollection([Accuracy(), Precision(), Recall()])
metrics = MetricCollection([
MulticlassAccuracy(), MulticlassPrecision(), MulticlassRecall()
])
self.train_metrics = metrics.clone(prefix='train_')
self.valid_metrics = metrics.clone(prefix='val_')

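The ``log_dict`` testcode in the hunk above is truncated after ``__init__``. A sketch of how the cloned collections are typically used in the step methods, assuming a 3-class setup; the backbone, loss, and epoch-end hook are hypothetical and not part of the diff.

```python
import torch
from pytorch_lightning import LightningModule
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall


class MyModule(LightningModule):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.backbone = torch.nn.Linear(16, num_classes)  # hypothetical model
        metrics = MetricCollection([
            MulticlassAccuracy(num_classes=num_classes),
            MulticlassPrecision(num_classes=num_classes),
            MulticlassRecall(num_classes=num_classes),
        ])
        self.train_metrics = metrics.clone(prefix="train_")
        self.valid_metrics = metrics.clone(prefix="val_")

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.backbone(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # the collection returns a dict of results, so a single log_dict call
        # replaces several individual self.log calls
        self.log_dict(self.train_metrics(logits, y))
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # only update the states here; compute and log the accumulated values at epoch end
        self.valid_metrics.update(self.backbone(x), y)

    def validation_epoch_end(self, outputs):
        self.log_dict(self.valid_metrics.compute())
        self.valid_metrics.reset()
```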
8 changes: 5 additions & 3 deletions docs/source/pages/quickstart.rst
@@ -64,12 +64,14 @@ The code-snippet below shows a simple example for calculating the accuracy using
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
acc = torchmetrics.functional.accuracy(preds, target, task='multiclass', num_classes=5)

Module metrics
~~~~~~~~~~~~~~

Nearly all functional metrics have a corresponding class-based metric that calls its functional counterpart underneath. The class-based metrics are characterized by having one or more internal metrics states (similar to the parameters of the PyTorch module) that allow them to offer additional functionalities:
Nearly all functional metrics have a corresponding class-based metric that calls its functional counterpart underneath.
The class-based metrics are characterized by having one or more internal metrics states (similar to the parameters of
the PyTorch module) that allow them to offer additional functionalities:

* Accumulation of multiple batches
* Automatic synchronization between multiple devices
@@ -84,7 +86,7 @@ The code below shows how to use the class-based interface:
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()
metric = torchmetrics.Accuracy(task='multiclass', num_classes=5)

n_batches = 10
for i in range(n_batches):
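The class-based snippet above is truncated at the start of the batch loop. A sketch of how such an accumulation loop typically continues, assuming the same 5-class setup; the printed strings are illustrative only.

```python
import torch
import torchmetrics

metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # the forward call updates the internal state and returns the value for this batch
    batch_acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {batch_acc}")

# compute() aggregates the state accumulated over all seen batches
total_acc = metric.compute()
print(f"Accuracy on all data: {total_acc}")

# reset() clears the state so the metric can be reused, e.g. for the next epoch
metric.reset()
```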