
Commit 4c93282

Add MetricGroup feature (#3266)
* Initial commit
* Add tests
* Fix two typos
* Fix Mypy
* Fix engine mypy issue
* Fix docstring
* Fix another problem in docstring

Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 65352ad commit 4c93282

5 files changed: +177 -2 lines changed

docs/source/metrics.rst

Lines changed: 1 addition & 0 deletions

@@ -335,6 +335,7 @@ Complete list of metrics
     MeanPairwiseDistance
     MeanSquaredError
     metric.Metric
+    metric_group.MetricGroup
     metrics_lambda.MetricsLambda
     MultiLabelConfusionMatrix
     MutualInformation

ignite/engine/engine.py

Lines changed: 2 additions & 2 deletions

@@ -157,7 +157,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
         _check_signature(process_function, "process_function", self, None)

         # generator provided by self._internal_run_as_gen
-        self._internal_run_generator: Optional[Generator] = None
+        self._internal_run_generator: Optional[Generator[Any, None, State]] = None

     def register_events(
         self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
@@ -951,7 +951,7 @@ def _internal_run(self) -> State:
         self._internal_run_generator = None
         return out.value

-    def _internal_run_as_gen(self) -> Generator:
+    def _internal_run_as_gen(self) -> Generator[Any, None, State]:
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
         try:
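For context on the change above: Generator[Any, None, State] narrows the bare Generator hint to a generator that yields anything, is never sent a value, and returns a State when exhausted. A minimal sketch of those semantics (the State stub here is illustrative, not ignite's actual class):

    from typing import Any, Generator

    class State:  # illustrative stub standing in for ignite.engine.State
        pass

    def run_as_gen() -> Generator[Any, None, State]:
        # Yield type is Any; send type is None (nothing is sent in);
        # return type is State, surfaced via StopIteration.value.
        yield "epoch_completed"
        return State()

    gen = run_as_gen()
    try:
        while True:
            next(gen)
    except StopIteration as stop:
        final_state = stop.value  # the State returned by the generator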

ignite/metrics/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@
 from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
 from ignite.metrics.mean_squared_error import MeanSquaredError
 from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
+from ignite.metrics.metric_group import MetricGroup
 from ignite.metrics.metrics_lambda import MetricsLambda
 from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
 from ignite.metrics.mutual_information import MutualInformation
@@ -41,6 +42,7 @@
     "Metric",
     "Accuracy",
     "Loss",
+    "MetricGroup",
     "MetricsLambda",
     "MeanAbsoluteError",
     "MeanPairwiseDistance",

ignite/metrics/metric_group.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+from typing import Any, Callable, Dict, Sequence
+
+import torch
+
+from ignite.metrics import Metric
+
+
+class MetricGroup(Metric):
+    """
+    A class for grouping metrics so that the user can manage them more easily.
+
+    Args:
+        metrics: a dictionary of names to metric instances.
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. The ``output_transform`` of each metric in the
+            group is also called upon its update.
+
+    Examples:
+        We construct a group of metrics, attach them to the engine at once and retrieve their result.
+
+        .. code-block:: python
+
+            import torch
+
+            metric_group = MetricGroup({'acc': Accuracy(), 'precision': Precision(), 'loss': Loss(nn.NLLLoss())})
+            metric_group.attach(default_evaluator, "eval_metrics")
+            y_true = torch.tensor([1, 0, 1, 1, 0, 1])
+            y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
+            state = default_evaluator.run([[y_pred, y_true]])
+
+            # Metrics individually available in `state.metrics`
+            state.metrics["acc"], state.metrics["precision"], state.metrics["loss"]
+
+            # And also altogether
+            state.metrics["eval_metrics"]
+    """
+
+    _state_dict_all_req_keys = ("metrics",)
+
+    def __init__(self, metrics: Dict[str, Metric], output_transform: Callable = lambda x: x):
+        self.metrics = metrics
+        super(MetricGroup, self).__init__(output_transform=output_transform)
+
+    def reset(self) -> None:
+        for m in self.metrics.values():
+            m.reset()
+
+    def update(self, output: Sequence[torch.Tensor]) -> None:
+        for m in self.metrics.values():
+            m.update(m._output_transform(output))
+
+    def compute(self) -> Dict[str, Any]:
+        return {k: m.compute() for k, m in self.metrics.items()}
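The docstring example above assumes default_evaluator from ignite's documentation setup. A self-contained sketch of the same flow, using a trivial Engine whose process function just echoes each (y_pred, y) batch (the Loss member is omitted to keep the batch format simple):

    import torch

    from ignite.engine import Engine
    from ignite.metrics import Accuracy, MetricGroup, Precision

    # Echo evaluator: each batch is already the (y_pred, y) pair the metrics consume.
    evaluator = Engine(lambda engine, batch: batch)

    metric_group = MetricGroup({"acc": Accuracy(), "precision": Precision()})
    metric_group.attach(evaluator, "eval_metrics")

    y_true = torch.tensor([1, 0, 1, 1, 0, 1])
    y_pred = torch.tensor([1, 0, 1, 0, 1, 1])
    state = evaluator.run([(y_pred, y_true)])

    print(state.metrics["acc"], state.metrics["precision"])  # individual results
    print(state.metrics["eval_metrics"])                     # grouped dict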
Tests for MetricGroup

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
+import pytest
+import torch
+
+from ignite import distributed as idist
+from ignite.engine import Engine
+from ignite.metrics import Accuracy, MetricGroup, Precision
+
+torch.manual_seed(41)
+
+
+def test_update():
+    precision = Precision()
+    accuracy = Accuracy()
+
+    group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})
+
+    y_pred = torch.randint(0, 2, (100,))
+    y = torch.randint(0, 2, (100,))
+
+    precision.update((y_pred, y))
+    accuracy.update((y_pred, y))
+    group.update((y_pred, y))
+
+    assert precision.state_dict() == group.metrics["precision"].state_dict()
+    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()
+
+
+def test_output_transform():
+    def drop_first(output):
+        y_pred, y = output
+        return (y_pred[1:], y[1:])
+
+    precision = Precision(output_transform=drop_first)
+    accuracy = Accuracy(output_transform=drop_first)
+
+    group = MetricGroup(
+        {"precision": Precision(output_transform=drop_first), "accuracy": Accuracy(output_transform=drop_first)}
+    )
+
+    y_pred = torch.randint(0, 2, (100,))
+    y = torch.randint(0, 2, (100,))
+
+    precision.update(drop_first(drop_first((y_pred, y))))
+    accuracy.update(drop_first(drop_first((y_pred, y))))
+    group.update(drop_first((y_pred, y)))
+
+    assert precision.state_dict() == group.metrics["precision"].state_dict()
+    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()
+
+
+def test_compute():
+    precision = Precision()
+    accuracy = Accuracy()
+
+    group = MetricGroup({"precision": Precision(), "accuracy": Accuracy()})
+
+    for _ in range(3):
+        y_pred = torch.randint(0, 2, (100,))
+        y = torch.randint(0, 2, (100,))
+
+        precision.update((y_pred, y))
+        accuracy.update((y_pred, y))
+        group.update((y_pred, y))
+
+    assert group.compute() == {"precision": precision.compute(), "accuracy": accuracy.compute()}
+
+    precision.reset()
+    accuracy.reset()
+    group.reset()
+
+    assert precision.state_dict() == group.metrics["precision"].state_dict()
+    assert accuracy.state_dict() == group.metrics["accuracy"].state_dict()
+
+
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    def test_integration(self):
+        rank = idist.get_rank()
+        torch.manual_seed(12 + rank)
+
+        n_epochs = 3
+        n_iters = 5
+        batch_size = 10
+        device = idist.device()
+
+        y_true = torch.randint(0, 2, size=(n_iters * batch_size,)).to(device)
+        y_pred = torch.randint(0, 2, (n_iters * batch_size,)).to(device)
+
+        def update(_, i):
+            return (
+                y_pred[i * batch_size : (i + 1) * batch_size],
+                y_true[i * batch_size : (i + 1) * batch_size],
+            )
+
+        engine = Engine(update)
+
+        precision = Precision()
+        precision.attach(engine, "precision")
+
+        accuracy = Accuracy()
+        accuracy.attach(engine, "accuracy")
+
+        group = MetricGroup({"eval_metrics.accuracy": Accuracy(), "eval_metrics.precision": Precision()})
+        group.attach(engine, "eval_metrics")
+
+        data = list(range(n_iters))
+        engine.run(data=data, max_epochs=n_epochs)
+
+        assert "eval_metrics" in engine.state.metrics
+        assert "eval_metrics.accuracy" in engine.state.metrics
+        assert "eval_metrics.precision" in engine.state.metrics
+
+        assert engine.state.metrics["eval_metrics"] == {
+            "eval_metrics.accuracy": engine.state.metrics["accuracy"],
+            "eval_metrics.precision": engine.state.metrics["precision"],
+        }
+        assert engine.state.metrics["eval_metrics.accuracy"] == engine.state.metrics["accuracy"]
+        assert engine.state.metrics["eval_metrics.precision"] == engine.state.metrics["precision"]
