
Commit 2c79b7e

Add MutualInformation Metric (#3230)
* add MutualInformationMetric
* update test for MutualInformation metric
* format code for MutualInformation Metric
* update test for MutualInformation metric
* update test
* update docstring
* fix device compatibility
* fix test_accumulator_device for MutualInformation metric
* update doc
* modify docstring
* modify formula of docstring
* update formula of docstring
* update formula of docstring
* remove unused import
* add reference
* commonalize redundant code
* modify decorator
* add a comment
* fix decorator
1 parent 6177c80 commit 2c79b7e

File tree

5 files changed: +247 −1 lines changed

docs/source/metrics.rst

Lines changed: 1 addition & 0 deletions

@@ -337,6 +337,7 @@ Complete list of metrics
     metric.Metric
     metrics_lambda.MetricsLambda
     MultiLabelConfusionMatrix
+    MutualInformation
     precision.Precision
     PSNR
     recall.Recall

ignite/metrics/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@
 from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
 from ignite.metrics.metrics_lambda import MetricsLambda
 from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix
+from ignite.metrics.mutual_information import MutualInformation
 from ignite.metrics.nlp.bleu import Bleu
 from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
 from ignite.metrics.precision import Precision
@@ -57,6 +58,7 @@
     "mIoU",
     "JaccardIndex",
     "MultiLabelConfusionMatrix",
+    "MutualInformation",
     "Precision",
     "PSNR",
     "Recall",

ignite/metrics/entropy.py

Lines changed: 5 additions & 1 deletion

@@ -80,9 +80,13 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
 
         prob = F.softmax(y_pred, dim=1)
         log_prob = F.log_softmax(y_pred, dim=1)
+
+        self._update(prob, log_prob)
+
+    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
         entropy_sum = -torch.sum(prob * log_prob)
         self._sum_of_entropies += entropy_sum.to(self._device)
-        self._num_examples += y_pred.shape[0]
+        self._num_examples += prob.shape[0]
 
     @sync_all_reduce("_sum_of_entropies", "_num_examples")
     def compute(self) -> float:
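This hunk splits ``Entropy.update`` into the softmax step and a new ``_update`` accumulation hook, so subclasses can reuse the probability computation and only override how statistics are accumulated (replacing ``y_pred.shape[0]`` with ``prob.shape[0]`` is required because ``y_pred`` is no longer in scope inside ``_update``). A minimal sketch of the pattern this enables, assuming the refactored ``Entropy`` above; ``MeanMaxProbability`` is a hypothetical subclass for illustration, not part of this commit:

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Entropy
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce


class MeanMaxProbability(Entropy):
    """Hypothetical subclass: mean of the per-sample top-class probability."""

    _state_dict_all_req_keys = ("_sum_of_max_probs",)

    @reinit__is_reduced
    def reset(self) -> None:
        super().reset()  # resets _sum_of_entropies and _num_examples
        self._sum_of_max_probs = torch.tensor(0.0, device=self._device)

    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
        super()._update(prob, log_prob)  # keep the base entropy accumulation
        # prob is already softmax-ed by Entropy.update, so nothing is recomputed here.
        self._sum_of_max_probs += prob.max(dim=1).values.sum().to(self._device)

    @sync_all_reduce("_sum_of_max_probs", "_num_examples")
    def compute(self) -> float:
        if self._num_examples == 0:
            raise NotComputableError("MeanMaxProbability must have at least one example before it can be computed.")
        return (self._sum_of_max_probs / self._num_examples).item()

``MutualInformation``, added below, follows exactly this pattern with a ``_sum_of_probabilities`` accumulator.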
ignite/metrics/mutual_information.py

Lines changed: 94 additions & 0 deletions (new file)

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Entropy
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

__all__ = ["MutualInformation"]


class MutualInformation(Entropy):
    r"""Calculates the `mutual information <https://en.wikipedia.org/wiki/Mutual_information>`_
    between input :math:`X` and prediction :math:`Y`.

    .. math::
        \begin{align*}
            I(X;Y) &= H(Y) - H(Y|X) = H \left( \frac{1}{N}\sum_{i=1}^N \hat{\mathbf{p}}_i \right)
            - \frac{1}{N}\sum_{i=1}^N H(\hat{\mathbf{p}}_i), \\
            H(\mathbf{p}) &= -\sum_{c=1}^C p_c \log p_c,
        \end{align*}

    where :math:`\hat{\mathbf{p}}_i` is the prediction probability vector for the :math:`i`-th input,
    and :math:`H(\mathbf{p})` is the entropy of :math:`\mathbf{p}`.

    Intuitively, this metric measures how well input data are clustered by classes in the feature space [1].

    [1] https://proceedings.mlr.press/v70/hu17b.html

    - ``update`` must receive output of the form ``(y_pred, y)``; ``y`` is not used in this metric.
    - ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
      or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
        to the metric to transform the output into the form expected by the metric.

        For more information on how metrics work with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = MutualInformation()
            metric.attach(default_evaluator, 'mutual_information')
            y_true = torch.tensor([0, 1, 2])  # not considered in the MutualInformation metric
            y_pred = torch.tensor([
                [ 0.0000,  0.6931,  1.0986],
                [ 1.3863,  1.6094,  1.6094],
                [ 0.0000, -2.3026, -2.3026]
            ])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['mutual_information'])

        .. testoutput::

            0.18599730730056763
    """

    _state_dict_all_req_keys = ("_sum_of_probabilities",)

    @reinit__is_reduced
    def reset(self) -> None:
        super().reset()
        self._sum_of_probabilities = torch.tensor(0.0, device=self._device)

    def _update(self, prob: torch.Tensor, log_prob: torch.Tensor) -> None:
        super()._update(prob, log_prob)
        # We can't use += below as _sum_of_probabilities can be a scalar and prob.sum(dim=0) is a vector
        self._sum_of_probabilities = self._sum_of_probabilities + prob.sum(dim=0).to(self._device)

    @sync_all_reduce("_sum_of_probabilities", "_sum_of_entropies", "_num_examples")
    def compute(self) -> float:
        n = self._num_examples
        if n == 0:
            raise NotComputableError("MutualInformation must have at least one example before it can be computed.")

        marginal_prob = self._sum_of_probabilities / n
        marginal_ent = -(marginal_prob * torch.log(marginal_prob)).sum()
        conditional_ent = self._sum_of_entropies / n
        mi = marginal_ent - conditional_ent
        mi = torch.clamp(mi, min=0.0)  # mutual information cannot be negative
        return float(mi.item())
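The ``compute`` method realizes the decomposition from the docstring: the entropy of the batch-averaged distribution minus the mean per-sample entropy. As a sanity check, the same quantity can be computed directly from logits with NumPy/SciPy; this sketch mirrors the ``np_mutual_information`` reference used in the tests below, with random logits used purely for illustration:

import numpy as np
from scipy.special import softmax
from scipy.stats import entropy

rng = np.random.default_rng(0)
logits = rng.normal(size=(100, 10))  # (B, C) unnormalized logits

prob = softmax(logits, axis=1)                  # per-sample predictive distributions
marginal_ent = entropy(prob.mean(axis=0))       # H(Y): entropy of the average distribution
conditional_ent = entropy(prob, axis=1).mean()  # H(Y|X): average per-sample entropy
mi = max(0.0, marginal_ent - conditional_ent)   # clamp tiny negative values to zero
print(mi)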
tests/ignite/metrics/test_mutual_information.py

Lines changed: 145 additions & 0 deletions (new file)

from typing import Tuple

import numpy as np
import pytest
import torch
from scipy.special import softmax
from scipy.stats import entropy
from torch import Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import MutualInformation


def np_mutual_information(np_y_pred: np.ndarray) -> float:
    prob = softmax(np_y_pred, axis=1)
    marginal_ent = entropy(np.mean(prob, axis=0))
    conditional_ent = np.mean(entropy(prob, axis=1))
    return max(0.0, marginal_ent - conditional_ent)


def test_zero_sample():
    mi = MutualInformation()
    with pytest.raises(
        NotComputableError, match=r"MutualInformation must have at least one example before it can be computed"
    ):
        mi.compute()


def test_invalid_shape():
    mi = MutualInformation()
    y_pred = torch.randn(10).float()
    with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
        mi.update((y_pred, None))


@pytest.fixture(params=list(range(4)))
def test_case(request):
    return [
        (torch.randn((100, 10)).float(), torch.randint(0, 10, size=[100]), 1),
        (torch.rand((100, 500)).float(), torch.randint(0, 500, size=[100]), 1),
        # updated batches
        (torch.normal(0.0, 5.0, size=(100, 10)).float(), torch.randint(0, 10, size=[100]), 16),
        (torch.normal(5.0, 3.0, size=(100, 200)).float(), torch.randint(0, 200, size=[100]), 16),
        # image segmentation
        (torch.randn((100, 5, 32, 32)).float(), torch.randint(0, 5, size=(100, 32, 32)), 16),
        (torch.randn((100, 5, 224, 224)).float(), torch.randint(0, 5, size=(100, 224, 224)), 16),
    ][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
    mi = MutualInformation()

    y_pred, y, batch_size = test_case

    mi.reset()
    if batch_size > 1:
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            mi.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        mi.update((y_pred, y))

    np_res = np_mutual_information(y_pred.numpy())
    res = mi.compute()

    assert isinstance(res, float)
    assert pytest.approx(np_res, rel=1e-4) == res


def test_accumulator_detached():
    mi = MutualInformation()

    y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
    y = torch.zeros(2)
    mi.update((y_pred, y))

    assert not mi._sum_of_probabilities.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        tol = 1e-4
        n_iters = 100
        batch_size = 10
        n_cls = 50
        device = idist.device()
        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)

        for metric_device in metric_devices:
            y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device)
            y_preds = torch.normal(0.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device)

            engine = Engine(
                lambda e, i: (
                    y_preds[i * batch_size : (i + 1) * batch_size],
                    y_true[i * batch_size : (i + 1) * batch_size],
                )
            )

            m = MutualInformation(device=metric_device)
            m.attach(engine, "mutual_information")

            data = list(range(n_iters))
            engine.run(data=data, max_epochs=1)

            y_preds = idist.all_gather(y_preds)
            y_true = idist.all_gather(y_true)

            assert "mutual_information" in engine.state.metrics
            res = engine.state.metrics["mutual_information"]

            true_res = np_mutual_information(y_preds.cpu().numpy())

            assert pytest.approx(true_res, rel=tol) == res

    def test_accumulator_device(self):
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)
        for metric_device in metric_devices:
            mi = MutualInformation(device=metric_device)

            devices = (mi._device, mi._sum_of_probabilities.device)
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

            y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
            y = torch.zeros(2)
            mi.update((y_pred, y))

            devices = (mi._device, mi._sum_of_probabilities.device)
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

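Beyond the Engine-based tests above, the metric can also be driven manually through its reset/update/compute API; a minimal sketch with made-up logits and labels:

import torch

from ignite.metrics import MutualInformation

metric = MutualInformation()
metric.reset()

# Feed two batches of (logits, labels); the labels are ignored by this metric.
for _ in range(2):
    y_pred = torch.randn(8, 5)          # (B, C) unnormalized logits
    y = torch.randint(0, 5, size=(8,))  # placeholder targets
    metric.update((y_pred, y))

print(metric.compute())  # a float in [0, log C]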