@@ -437,29 +437,25 @@ def compute_true_somemetric(y_pred, y):
 def _test_distrib_integration(device):
 
     rank = idist.get_rank()
-    np.random.seed(12)
 
     n_iters = 10
     batch_size = 10
     n_classes = 10
 
     def _test(metric_device):
-        y_true = np.arange(0, n_iters * batch_size * idist.get_world_size(), dtype="int64") % n_classes
-        y_pred = 0.2 * np.random.rand(n_iters * batch_size * idist.get_world_size(), n_classes)
-        for i in range(n_iters * batch_size * idist.get_world_size()):
+        y_true = torch.arange(0, n_iters * batch_size, dtype=torch.int64).to(device) % n_classes
+        y_pred = 0.2 * torch.rand(n_iters * batch_size, n_classes).to(device)
+        for i in range(n_iters * batch_size):
             if np.random.rand() > 0.4:
                 y_pred[i, y_true[i]] = 1.0
             else:
                 j = np.random.randint(0, n_classes)
                 y_pred[i, j] = 0.7
 
-        y_true = y_true.reshape(n_iters * idist.get_world_size(), batch_size)
-        y_pred = y_pred.reshape(n_iters * idist.get_world_size(), batch_size, n_classes)
-
         def update_fn(engine, i):
-            y_true_batch = y_true[i + rank * n_iters, ...]
-            y_pred_batch = y_pred[i + rank * n_iters, ...]
-            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+            y_true_batch = y_true[i * batch_size : (i + 1) * batch_size, ...]
+            y_pred_batch = y_pred[i * batch_size : (i + 1) * batch_size, ...]
+            return y_pred_batch, y_true_batch
 
 
         evaluator = Engine(update_fn)
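Note on the hunk above: test data is now generated per rank as torch tensors on `device` instead of one global numpy array indexed by `rank`, so `update_fn` returns tensor slices directly with no `torch.from_numpy` conversion. A minimal sketch of that shard-then-gather pattern, assuming `ignite.distributed`; the shapes and the `gloo` backend are illustrative choices, not taken from the commit:

```python
# A minimal sketch (not the test itself): each process seeds with its rank,
# builds only its own shard, and idist.all_gather concatenates the per-rank
# tensors along dim 0 so every rank can run a global reference computation.
import torch
import ignite.distributed as idist


def run(local_rank):
    rank = idist.get_rank()
    torch.manual_seed(12 + rank)      # per-rank seed, as in the test
    shard = torch.rand(4, 3)          # this rank's slice of the data only
    full = idist.all_gather(shard)    # full data, identical on every rank
    assert full.shape[0] == 4 * idist.get_world_size()


if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(run)
```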
@@ -478,13 +474,19 @@ def Fbeta(r, p, beta):
         data = list(range(n_iters))
         state = evaluator.run(data, max_epochs=1)
 
+        y_pred = idist.all_gather(y_pred)
+        y_true = idist.all_gather(y_true)
+
         assert "f1" in state.metrics
         assert "ff1" in state.metrics
-        f1_true = f1_score(y_true.ravel(), np.argmax(y_pred.reshape(-1, n_classes), axis=-1), average="macro")
+        f1_true = f1_score(
+            y_true.ravel().cpu(), np.argmax(y_pred.reshape(-1, n_classes).cpu(), axis=-1), average="macro"
+        )
         assert f1_true == approx(state.metrics["f1"])
         assert 1.0 + f1_true == approx(state.metrics["ff1"])
 
-    for _ in range(3):
+    for i in range(3):
+        torch.manual_seed(12 + rank + i)
         _test("cpu")
         if device.type != "xla":
             _test(idist.device())
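After `all_gather`, every rank holds the full `y_pred`/`y_true`, so the sklearn reference `f1_true` is computed on identical data everywhere. The `ff1` assertion relies on ignite's metric arithmetic: operators applied to `Metric` objects build a `MetricsLambda`, which is how `1.0 + F1` can itself be attached as a metric. A single-process sketch of that composition, with toy data and hypothetical shapes:

```python
# Arithmetic on ignite Metric objects yields a MetricsLambda, so an F1-style
# metric can be composed from Precision and Recall and attached like any
# other metric. Same composition as the test's "ff1", without the +1.0 offset.
import torch
from ignite.engine import Engine
from ignite.metrics import Precision, Recall

evaluator = Engine(lambda engine, batch: batch)  # batches are (y_pred, y)

precision = Precision(average=False)
recall = Recall(average=False)
f1 = (precision * recall * 2 / (precision + recall + 1e-20)).mean()
f1.attach(evaluator, "f1")

y_pred = torch.rand(16, 4)            # random scores over 4 classes
y = torch.randint(0, 4, size=(16,))
state = evaluator.run([(y_pred, y)], max_epochs=1)
print(state.metrics["f1"])
```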
@@ -493,28 +495,44 @@ def Fbeta(r, p, beta):
 def _test_distrib_metrics_on_diff_devices(device):
     n_classes = 10
     n_iters = 12
-    s = 16
-    offset = n_iters * s
+    batch_size = 16
     rank = idist.get_rank()
+    torch.manual_seed(12 + rank)
 
-    y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
-    y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
+    y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
+    y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
 
     def update(engine, i):
         return (
-            y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
-            y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
+            y_preds[i * batch_size : (i + 1) * batch_size, :],
+            y_true[i * batch_size : (i + 1) * batch_size],
         )
 
+    evaluator = Engine(update)
+
     precision = Precision(average=False, device="cpu")
     recall = Recall(average=False, device=device)
-    custom_metric = precision * recall
 
-    engine = Engine(update)
-    custom_metric.attach(engine, "custom_metric")
+    def Fbeta(r, p, beta):
+        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()
+
+    F1 = MetricsLambda(Fbeta, recall, precision, 1)
+    F1.attach(evaluator, "f1")
+
+    another_f1 = (1.0 + precision * recall * 2 / (precision + recall + 1e-20)).mean().item()
+    another_f1.attach(evaluator, "ff1")
 
     data = list(range(n_iters))
-    engine.run(data, max_epochs=2)
+    state = evaluator.run(data, max_epochs=1)
+
+    y_preds = idist.all_gather(y_preds)
+    y_true = idist.all_gather(y_true)
+
+    assert "f1" in state.metrics
+    assert "ff1" in state.metrics
+    f1_true = f1_score(y_true.ravel(), np.argmax(y_preds.reshape(-1, n_classes), axis=-1), average="macro")
+    assert f1_true == approx(state.metrics["f1"])
+    assert 1.0 + f1_true == approx(state.metrics["ff1"])
 
 
 @pytest.mark.distributed
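The rewritten `_test_distrib_metrics_on_diff_devices` follows the same shape: per-rank batches, `Precision` accumulating on `"cpu"` while `Recall` accumulates on `device`, and a gathered sklearn cross-check at the end. A sketch of just that final reference check, assuming `all_gather` has already run so every rank holds the full tensors; sizes are toy values, and `.cpu()` is added here because sklearn/numpy require CPU data:

```python
# Sketch of the reference check: macro F1 over argmax predictions, computed
# the same way on every rank from the gathered tensors (toy sizes below).
import numpy as np
import torch
from sklearn.metrics import f1_score

n_classes = 10
y_true = torch.randint(0, n_classes, size=(192,))
y_preds = torch.rand(192, n_classes)

f1_true = f1_score(
    y_true.ravel().cpu(),
    np.argmax(y_preds.reshape(-1, n_classes).cpu(), axis=-1),
    average="macro",
)
print(f1_true)  # the test compares this to state.metrics["f1"] via approx
```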