Commit 206c52a

rename example metrics
SunQpark committed Sep 10, 2019
1 parent 41505c2 commit 206c52a
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -130,7 +130,7 @@ Config files are in `.json` format:
   },
   "loss": "nll_loss", // loss
   "metrics": [
-    "my_metric", "my_metric2" // list of metrics to evaluate
+    "accuracy", "top_k_acc" // list of metrics to evaluate
   ],
   "lr_scheduler": {
     "type": "StepLR", // learning rate scheduler
@@ -285,7 +285,7 @@ Metric functions are located in 'model/metric.py'.
 
 You can monitor multiple metrics by providing a list in the configuration file, e.g.:
 ```json
-  "metrics": ["my_metric", "my_metric2"],
+  "metrics": ["accuracy", "top_k_acc"],
 ```
 
 ### Additional logging
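The entries under "metrics" are plain strings that must match function names defined in model/metric.py, which is why README.md, config.json, and model/metric.py are all updated together in this commit. Below is a minimal sketch of such a name-based lookup; the getattr-style resolution is an assumption made for illustration, not code taken from this diff.

```python
# Hypothetical sketch: resolving metric names from config.json to callables
# in model/metric.py. The getattr-based lookup is an assumption about how
# the template wires names to functions; it is not part of this commit.
import json

import model.metric as module_metric

with open('config.json') as f:
    config = json.load(f)

# "accuracy" and "top_k_acc" must exist as functions in model/metric.py,
# so renaming them means touching the config files and the module together.
metric_fns = [getattr(module_metric, name) for name in config['metrics']]
```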
2 changes: 1 addition & 1 deletion config.json
@@ -26,7 +26,7 @@
   },
   "loss": "nll_loss",
   "metrics": [
-    "my_metric", "my_metric2"
+    "accuracy", "top_k_acc"
   ],
   "lr_scheduler": {
     "type": "StepLR",
4 changes: 2 additions & 2 deletions model/metric.py
@@ -1,7 +1,7 @@
 import torch
 
 
-def my_metric(output, target):
+def accuracy(output, target):
     with torch.no_grad():
         pred = torch.argmax(output, dim=1)
         assert pred.shape[0] == len(target)
@@ -10,7 +10,7 @@ def my_metric(output, target):
     return correct / len(target)
 
 
-def my_metric2(output, target, k=3):
+def top_k_acc(output, target, k=3):
     with torch.no_grad():
         pred = torch.topk(output, k, dim=1)[1]
         assert pred.shape[0] == len(target)
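The two hunks above leave lines 8–9 of model/metric.py out of view. For readability, here is a hedged reconstruction of the renamed metric functions; the lines that accumulate the correct-prediction count are an assumption based on the standard accuracy computation, not text from this commit.

```python
# Hedged reconstruction of the renamed metrics. The "correct" accumulation
# lines are assumed; everything else mirrors the diff above.
import torch


def accuracy(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        assert pred.shape[0] == len(target)
        correct = torch.sum(pred == target).item()  # assumed line
    return correct / len(target)


def top_k_acc(output, target, k=3):
    with torch.no_grad():
        pred = torch.topk(output, k, dim=1)[1]
        assert pred.shape[0] == len(target)
        correct = 0                                            # assumed
        for i in range(k):                                     # assumed
            correct += torch.sum(pred[:, i] == target).item()  # assumed
    return correct / len(target)
```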
