+import os
+from typing import Tuple
 from unittest.mock import patch

 import numpy as np
 import torch
 from sklearn.metrics import precision_recall_curve

+import ignite.distributed as idist
 from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
 from ignite.engine import Engine
 from ignite.metrics.epoch_metric import EpochMetricWarning
@@ -38,9 +41,12 @@ def test_precision_recall_curve():

     precision_recall_curve_metric.update((y_pred, y))
     precision, recall, thresholds = precision_recall_curve_metric.compute()
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()

-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall

     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

@@ -70,9 +76,11 @@ def update_fn(engine, batch):

     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
-
-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall

     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

@@ -103,9 +111,12 @@ def update_fn(engine, batch):

     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()

-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall

     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

@@ -124,3 +135,182 @@ def test_check_compute_fn():

     em = PrecisionRecallCurve(check_compute_fn=False)
     em.update(output)
+
+
+def _test_distrib_compute(device):
+
+    rank = idist.get_rank()
+    torch.manual_seed(12)
+
+    def _test(y_pred, y, batch_size, metric_device):
+
+        metric_device = torch.device(metric_device)
+        prc = PrecisionRecallCurve(device=metric_device)
+
+        torch.manual_seed(10 + rank)
+
+        prc.reset()
+        if batch_size > 1:
+            n_iters = y.shape[0] // batch_size + 1
+            for i in range(n_iters):
+                idx = i * batch_size
+                prc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+        else:
+            prc.update((y_pred, y))
+
+        # gather y_pred, y
+        y_pred = idist.all_gather(y_pred)
+        y = idist.all_gather(y)
+
+        np_y = y.cpu().numpy()
+        np_y_pred = y_pred.cpu().numpy()
+
+        res = prc.compute()
+
+        assert isinstance(res, Tuple)
+        assert precision_recall_curve(np_y, np_y_pred)[0] == pytest.approx(res[0])
+        assert precision_recall_curve(np_y, np_y_pred)[1] == pytest.approx(res[1])
+        assert precision_recall_curve(np_y, np_y_pred)[2] == pytest.approx(res[2])
+
+    def get_test_cases():
+        test_cases = [
+            # Binary input data of shape (N,) or (N, 1)
+            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
+            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
+            # updated batches
+            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
+            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
+        ]
+        return test_cases
+
+    for _ in range(5):
+        test_cases = get_test_cases()
+        for y_pred, y, batch_size in test_cases:
+            _test(y_pred, y, batch_size, "cpu")
+            if device.type != "xla":
+                _test(y_pred, y, batch_size, idist.device())
+
+
+def _test_distrib_integration(device):
+
+    rank = idist.get_rank()
+    torch.manual_seed(12)
+
+    def _test(n_epochs, metric_device):
+        metric_device = torch.device(metric_device)
+        n_iters = 80
+        size = 151
+        y_true = torch.randint(0, 2, (size,)).to(device)
+        y_preds = torch.randint(0, 2, (size,)).to(device)
+
+        def update(engine, i):
+            return (
+                y_preds[i * size : (i + 1) * size],
+                y_true[i * size : (i + 1) * size],
+            )
+
+        engine = Engine(update)
+
+        prc = PrecisionRecallCurve(device=metric_device)
+        prc.attach(engine, "prc")
+
+        data = list(range(n_iters))
+        engine.run(data=data, max_epochs=n_epochs)
+
+        assert "prc" in engine.state.metrics
+
+        precision, recall, thresholds = engine.state.metrics["prc"]
+
+        np_y_true = y_true.cpu().numpy().ravel()
+        np_y_preds = y_preds.cpu().numpy().ravel()
+
+        sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)
+
+        assert precision.shape == sk_precision.shape
+        assert recall.shape == sk_recall.shape
+        assert thresholds.shape == sk_thresholds.shape
+        assert pytest.approx(precision) == sk_precision
+        assert pytest.approx(recall) == sk_recall
+        assert pytest.approx(thresholds) == sk_thresholds
+
+    metric_devices = ["cpu"]
+    if device.type != "xla":
+        metric_devices.append(idist.device())
+    for metric_device in metric_devices:
+        for _ in range(2):
+            _test(n_epochs=1, metric_device=metric_device)
+            _test(n_epochs=2, metric_device=metric_device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
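
For orientation, and not part of the diff itself: below is a minimal, single-process sketch of the comparison pattern the updated tests rely on, mirroring `test_precision_recall_curve` above. The variable names, sizes, and seed are arbitrary choices; `np.asarray` is used so the checks work whether `compute()` returns torch tensors (as the added `.numpy()` calls above suggest) or NumPy arrays.

# Minimal single-process sketch (illustrative only; names and sizes are arbitrary).
import numpy as np
import torch
from sklearn.metrics import precision_recall_curve

from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve

np.random.seed(0)
size = 100
np_y_pred = np.random.rand(size)   # probability-like scores
np_y = np.zeros((size,))
np_y[size // 2 :] = 1              # second half labelled positive

metric = PrecisionRecallCurve()
metric.update((torch.from_numpy(np_y_pred), torch.from_numpy(np_y)))
precision, recall, thresholds = metric.compute()

# Reference values from scikit-learn.
sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y, np_y_pred)

np.testing.assert_allclose(np.asarray(precision), sk_precision)
np.testing.assert_allclose(np.asarray(recall), sk_recall)
# Thresholds are only approximately equal due to the numpy -> torch -> numpy conversion.
np.testing.assert_array_almost_equal(np.asarray(thresholds), sk_thresholds)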