Skip to content

Commit

Permalink
Merge branch 'master' into dependabot-github_actions-pypa-gh-action-p…
Browse files Browse the repository at this point in the history
…ypi-publish-1.6.4
  • Loading branch information
Borda authored Jan 3, 2023
2 parents 953b038 + 9c5a7b2 commit 9de26a7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
22 changes: 11 additions & 11 deletions docs/source/pages/lightning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ The example below shows how to use a metric in your `LightningModule <https://py

class MyModel(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.accuracy = torchmetrics.Accuracy(task='multiclass')
self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def training_step(self, batch, batch_idx):
x, y = batch
Expand Down Expand Up @@ -78,10 +78,10 @@ value by calling ``.compute()``.

class MyModule(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')
self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def training_step(self, batch, batch_idx):
x, y = batch
Expand All @@ -105,8 +105,8 @@ of the metrics.

def __init__(self):
...
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')
self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def training_step(self, batch, batch_idx):
x, y = batch
Expand Down Expand Up @@ -141,9 +141,9 @@ mixed as it can lead to wrong results.

class MyModule(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.valid_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def validation_step(self, batch, batch_idx):
logits = self(x)
Expand Down Expand Up @@ -185,9 +185,9 @@ The following contains a list of pitfalls to be aware of:

class MyModule(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.val_acc = nn.ModuleList([torchmetrics.Accuracy(task='multiclass') for _ in range(2)])
self.val_acc = nn.ModuleList([torchmetrics.Accuracy(task="multiclass", num_classes=num_classes) for _ in range(2)])

def val_dataloader(self):
return [DataLoader(...), DataLoader(...)]
Expand Down
4 changes: 2 additions & 2 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ inside your LightningModule. In most cases we just have to replace ``self.log``
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

class MyModule(LightningModule):
def __init__(self):
def __init__(self, num_classes):
metrics = MetricCollection([
MulticlassAccuracy(), MulticlassPrecision(), MulticlassRecall()
MulticlassAccuracy(num_classes), MulticlassPrecision(num_classes), MulticlassRecall(num_classes)
])
self.train_metrics = metrics.clone(prefix='train_')
self.valid_metrics = metrics.clone(prefix='val_')
Expand Down
4 changes: 2 additions & 2 deletions docs/source/pages/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ The code-snippet below shows a simple example for calculating the accuracy using
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target, task='multiclass', num_classes=5)
acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)

Module metrics
~~~~~~~~~~~~~~
Expand All @@ -86,7 +86,7 @@ The code below shows how to use the class-based interface:
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy(task='multiclass', num_classes=5)
metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
Expand Down

0 comments on commit 9de26a7

Please sign in to comment.