
Commit ccbd6af

sayantan1410, sdesrozis, and Desroziers authored
Updated precision_recall_curve.py (#2490)
* updated precision_recall_curve.py
* autopep8 fix
* removed unsed imports
* made some small changes
* solved unused import issue
* autopep8 fix
* reverted back some changes and changed epoch_metric.py
* autopep8 fix
* re written compute function for precision_recall_curve.py
* reverted back epoch_metric.py
* reverted back unnecessary changes to doc string
* reverted a line break that was added by mistake
* autopep8 fix
* corrected function annotation
* fixed mypy issues
* Added tests for GPU and TPU
* autopep8 fix
* fixed a few tests in precision_recall_curve
* autopep8 fix
* fixed a few errors for the tests
* autopep8 fix
* added tests for array shape
* autopep8 fix
* made some small changes
* Fixed all the errors in the tests
* fix distributed computation
* converted tensors to numpy array
* checking for approx equal

Co-authored-by: sayantan1410 <sayantan1410@users.noreply.github.com>
Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com>
Co-authored-by: Desroziers <sylvain.desroziers@michelin.com>
1 parent 3d9ba50 commit ccbd6af

File tree

3 files changed: 245 additions, 12 deletions


ignite/contrib/metrics/precision_recall_curve.py

Lines changed: 46 additions & 3 deletions

@@ -1,7 +1,9 @@
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, cast, Tuple, Union

 import torch

+import ignite.distributed as idist
+from ignite.exceptions import NotComputableError
 from ignite.metrics import EpochMetric


@@ -69,7 +71,48 @@ def sigmoid_output_transform(output):

     """

-    def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None:
+    def __init__(
+        self,
+        output_transform: Callable = lambda x: x,
+        check_compute_fn: bool = False,
+        device: Union[str, torch.device] = torch.device("cpu"),
+    ) -> None:
         super(PrecisionRecallCurve, self).__init__(
-            precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+            precision_recall_curve_compute_fn,
+            output_transform=output_transform,
+            check_compute_fn=check_compute_fn,
+            device=device,
         )
+
+    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if len(self._predictions) < 1 or len(self._targets) < 1:
+            raise NotComputableError("EpochMetric must have at least one example before it can be computed.")
+
+        _prediction_tensor = torch.cat(self._predictions, dim=0)
+        _target_tensor = torch.cat(self._targets, dim=0)
+
+        ws = idist.get_world_size()
+        if ws > 1 and not self._is_reduced:
+            # All gather across all processes
+            _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
+            _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
+            self._is_reduced = True
+
+        if idist.get_rank() == 0:
+            # Run compute_fn on zero rank only
+            precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
+            precision = torch.Tensor(precision)
+            recall = torch.Tensor(recall)
+            # thresholds can have negative strides, not compatible with torch tensors
+            # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
+            thresholds = torch.Tensor(thresholds.copy())
+        else:
+            precision, recall, thresholds = None, None, None
+
+        if ws > 1:
+            # broadcast result to all processes
+            precision = idist.broadcast(precision, src=0, safe_mode=True)
+            recall = idist.broadcast(recall, src=0, safe_mode=True)
+            thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)
+
+        return precision, recall, thresholds
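
For context, a minimal single-process sketch (not part of the commit) of how the updated metric is used; the tensor values are illustrative only:

import torch

from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve

y_pred = torch.tensor([0.1, 0.4, 0.35, 0.8])  # predicted scores
y_true = torch.tensor([0, 0, 1, 1])           # binary targets

# `device` is the keyword argument added by this commit; "cpu" is just an example value
prc = PrecisionRecallCurve(device="cpu")
prc.update((y_pred, y_true))

# compute() now returns torch tensors; in a distributed run the curve is
# computed on rank 0 and broadcast to all processes
precision, recall, thresholds = prc.compute()
print(precision, recall, thresholds)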

ignite/metrics/epoch_metric.py

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, cast, List, Tuple, Union
+from typing import Any, Callable, cast, List, Tuple, Union

 import torch

@@ -136,7 +136,7 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
         except Exception as e:
             warnings.warn(f"Probably, there can be a problem with `compute_fn`:\n {e}.", EpochMetricWarning)

-    def compute(self) -> float:
+    def compute(self) -> Any:
         if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

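
For context, the return annotation is widened to Any because compute_fn (and subclasses such as PrecisionRecallCurve above) may return structures other than a single float. A minimal single-process sketch (not part of the commit), using a hypothetical identity compute_fn:

from typing import Any, Tuple

import torch

from ignite.metrics import EpochMetric


def identity_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # a compute_fn may return an arbitrary structure, not just a scalar
    return y_preds, y_targets


metric = EpochMetric(identity_compute_fn)
metric.update((torch.tensor([0.2, 0.8]), torch.tensor([0, 1])))
result: Any = metric.compute()  # a tuple here, hence compute() -> Any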

tests/ignite/contrib/metrics/test_precision_recall_curve.py

Lines changed: 197 additions & 7 deletions

@@ -1,3 +1,5 @@
+import os
+from typing import Tuple
 from unittest.mock import patch

 import numpy as np
@@ -6,6 +8,7 @@
 import torch
 from sklearn.metrics import precision_recall_curve

+import ignite.distributed as idist
 from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
 from ignite.engine import Engine
 from ignite.metrics.epoch_metric import EpochMetricWarning
@@ -38,9 +41,12 @@ def test_precision_recall_curve():

     precision_recall_curve_metric.update((y_pred, y))
     precision, recall, thresholds = precision_recall_curve_metric.compute()
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()

-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

@@ -70,9 +76,11 @@ def update_fn(engine, batch):

     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
-
-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

@@ -103,9 +111,12 @@ def update_fn(engine, batch):

     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()

-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

@@ -124,3 +135,182 @@ def test_check_compute_fn():

     em = PrecisionRecallCurve(check_compute_fn=False)
     em.update(output)
+
+
+def _test_distrib_compute(device):
+
+    rank = idist.get_rank()
+    torch.manual_seed(12)
+
+    def _test(y_pred, y, batch_size, metric_device):
+
+        metric_device = torch.device(metric_device)
+        prc = PrecisionRecallCurve(device=metric_device)
+
+        torch.manual_seed(10 + rank)
+
+        prc.reset()
+        if batch_size > 1:
+            n_iters = y.shape[0] // batch_size + 1
+            for i in range(n_iters):
+                idx = i * batch_size
+                prc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+        else:
+            prc.update((y_pred, y))
+
+        # gather y_pred, y
+        y_pred = idist.all_gather(y_pred)
+        y = idist.all_gather(y)
+
+        np_y = y.cpu().numpy()
+        np_y_pred = y_pred.cpu().numpy()
+
+        res = prc.compute()
+
+        assert isinstance(res, Tuple)
+        assert precision_recall_curve(np_y, np_y_pred)[0] == pytest.approx(res[0])
+        assert precision_recall_curve(np_y, np_y_pred)[1] == pytest.approx(res[1])
+        assert precision_recall_curve(np_y, np_y_pred)[2] == pytest.approx(res[2])
+
+    def get_test_cases():
+        test_cases = [
+            # Binary input data of shape (N,) or (N, 1)
+            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
+            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
+            # updated batches
+            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
+            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
+        ]
+        return test_cases
+
+    for _ in range(5):
+        test_cases = get_test_cases()
+        for y_pred, y, batch_size in test_cases:
+            _test(y_pred, y, batch_size, "cpu")
+            if device.type != "xla":
+                _test(y_pred, y, batch_size, idist.device())
+
+
+def _test_distrib_integration(device):
+
+    rank = idist.get_rank()
+    torch.manual_seed(12)
+
+    def _test(n_epochs, metric_device):
+        metric_device = torch.device(metric_device)
+        n_iters = 80
+        size = 151
+        y_true = torch.randint(0, 2, (size,)).to(device)
+        y_preds = torch.randint(0, 2, (size,)).to(device)
+
+        def update(engine, i):
+            return (
+                y_preds[i * size : (i + 1) * size],
+                y_true[i * size : (i + 1) * size],
+            )
+
+        engine = Engine(update)
+
+        prc = PrecisionRecallCurve(device=metric_device)
+        prc.attach(engine, "prc")
+
+        data = list(range(n_iters))
+        engine.run(data=data, max_epochs=n_epochs)
+
+        assert "prc" in engine.state.metrics
+
+        precision, recall, thresholds = engine.state.metrics["prc"]
+
+        np_y_true = y_true.cpu().numpy().ravel()
+        np_y_preds = y_preds.cpu().numpy().ravel()
+
+        sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)
+
+        assert precision.shape == sk_precision.shape
+        assert recall.shape == sk_recall.shape
+        assert thresholds.shape == sk_thresholds.shape
+        assert pytest.approx(precision) == sk_precision
+        assert pytest.approx(recall) == sk_recall
+        assert pytest.approx(thresholds) == sk_thresholds
+
+    metric_devices = ["cpu"]
+    if device.type != "xla":
+        metric_devices.append(idist.device())
+    for metric_device in metric_devices:
+        for _ in range(2):
+            _test(n_epochs=1, metric_device=metric_device)
+            _test(n_epochs=2, metric_device=metric_device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
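
For context, a minimal sketch (not part of the commit) of the same gather / rank-0 compute / broadcast behaviour outside the test suite; the spawn settings (backend="gloo", nproc_per_node=2) and the random data are illustrative only:

import torch

import ignite.distributed as idist
from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve


def run(local_rank):
    torch.manual_seed(10 + local_rank)
    y_pred = torch.rand(8)                    # per-rank scores
    y_true = torch.randint(0, 2, size=(8,))   # per-rank binary targets

    prc = PrecisionRecallCurve(device=idist.device())
    prc.update((y_pred, y_true))

    # predictions/targets are all-gathered, sklearn runs on rank 0,
    # and the resulting tensors are broadcast back to every rank
    precision, recall, thresholds = prc.compute()
    print(idist.get_rank(), precision.shape, recall.shape, thresholds.shape)


if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(run)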
