
Commit 96f717b

puhuk, sadra-barikbin, and vfdev-5 authored
Update pairwise_dist, mse and lambda in test for generating data with different rank (#2670)
* Update pairwise_dist, mse and lambda in test for generating data with different rank

  Update `mean_pairwise_distance`, `mean_squared_error`, `metrics_lambda`

* Update with review
* Update test_metrics_lambda.py
* Update test_metrics_lambda.py

  To add `assert` clauses

* Update test_metrics_lambda.py
* Update test_metrics_lambda.py
* Update test_metrics_lambda.py
* Update test_metrics_lambda.py
* Update test_metrics_lambda.py
* Update test_metrics_lambda.py
* Update test_metrics_lambda.py
* Update test_metrics_lambda.py

Co-authored-by: Sadra Barikbin <sadraqazvin1@yahoo.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent b17ee25 commit 96f717b
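The pattern the updated tests converge on, sketched below in condensed form (this is not the literal test code; `metric_cls`, the tensor shapes, and the `"m"` metric name are placeholders): seed each rank differently, generate only that rank's shard of the data, feed local batches to the engine, and only afterwards `idist.all_gather` the tensors so the reference value is computed over the data from every process.

import torch

import ignite.distributed as idist
from ignite.engine import Engine

def _sketch_distrib_integration(device, metric_cls, metric_device):
    rank = idist.get_rank()
    torch.manual_seed(12 + rank)  # different data on every rank

    n_iters, batch_size = 100, 50
    # Each rank builds only its own shard instead of offset-indexing a global tensor.
    y_true = torch.rand(n_iters * batch_size, 10).to(device)
    y_preds = torch.rand(n_iters * batch_size, 10).to(device)

    def update(engine, i):
        return (
            y_preds[i * batch_size : (i + 1) * batch_size, ...],
            y_true[i * batch_size : (i + 1) * batch_size, ...],
        )

    engine = Engine(update)
    metric_cls(device=metric_device).attach(engine, "m")
    engine.run(data=list(range(n_iters)), max_epochs=1)

    # Gather every rank's tensors before computing the reference, so the check
    # covers the full dataset the distributed metric has seen.
    y_preds = idist.all_gather(y_preds)
    y_true = idist.all_gather(y_true)
    # ... compare engine.state.metrics["m"] against a reference computed on
    # the gathered y_true / y_preds ...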

3 files changed: +74, -46 lines changed


tests/ignite/metrics/test_mean_pairwise_distance.py

Lines changed: 19 additions & 13 deletions
@@ -56,22 +56,22 @@ def _test_distrib_integration(device):
     from ignite.engine import Engine
 
     rank = idist.get_rank()
-    torch.manual_seed(12)
+    torch.manual_seed(12 + rank)
 
-    n_iters = 100
-    s = 50
-    offset = n_iters * s
+    def _test(metric_device):
 
-    y_true = torch.rand(offset * idist.get_world_size(), 10).to(device)
-    y_preds = torch.rand(offset * idist.get_world_size(), 10).to(device)
+        n_iters = 100
+        batch_size = 50
 
-    def update(engine, i):
-        return (
-            y_preds[i * s + offset * rank : (i + 1) * s + offset * rank, ...],
-            y_true[i * s + offset * rank : (i + 1) * s + offset * rank, ...],
-        )
+        y_true = torch.rand(n_iters * batch_size, 10).to(device)
+        y_preds = torch.rand(n_iters * batch_size, 10).to(device)
+
+        def update(engine, i):
+            return (
+                y_preds[i * batch_size : (i + 1) * batch_size, ...],
+                y_true[i * batch_size : (i + 1) * batch_size, ...],
+            )
 
-    def _test(metric_device):
         engine = Engine(update)
 
         m = MeanPairwiseDistance(device=metric_device)
@@ -80,14 +80,20 @@ def _test(metric_device):
         data = list(range(n_iters))
         engine.run(data=data, max_epochs=1)
 
+        y_preds = idist.all_gather(y_preds)
+        y_true = idist.all_gather(y_true)
+
         assert "mpwd" in engine.state.metrics
         res = engine.state.metrics["mpwd"]
 
         true_res = []
         for i in range(n_iters * idist.get_world_size()):
             true_res.append(
                 torch.pairwise_distance(
-                    y_true[i * s : (i + 1) * s, ...], y_preds[i * s : (i + 1) * s, ...], p=m._p, eps=m._eps
+                    y_true[i * batch_size : (i + 1) * batch_size, ...],
+                    y_preds[i * batch_size : (i + 1) * batch_size, ...],
+                    p=m._p,
+                    eps=m._eps,
                 )
                 .cpu()
                 .numpy()
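A note on why the reference loop above runs over `n_iters * idist.get_world_size()` batches: `idist.all_gather` concatenates the per-rank tensors along the first dimension. A tiny self-contained illustration, not part of the commit; the shapes simply mirror this test, and with a single process `all_gather` is effectively a no-op:

import torch

import ignite.distributed as idist

def show_gathered_shape():
    # One rank's local shard, shaped like the test data above.
    local = torch.rand(100 * 50, 10)
    # all_gather concatenates every rank's tensor along dim 0, so the result
    # has world_size times as many rows as the local shard.
    gathered = idist.all_gather(local)
    assert gathered.shape[0] == idist.get_world_size() * local.shape[0]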

tests/ignite/metrics/test_mean_squared_error.py

Lines changed: 15 additions & 11 deletions
@@ -58,20 +58,21 @@ def _test_distrib_integration(device, tol=1e-6):
     from ignite.engine import Engine
 
     rank = idist.get_rank()
-    n_iters = 100
-    s = 10
-    offset = n_iters * s
+    torch.manual_seed(12 + rank)
 
-    y_true = torch.arange(0, offset * idist.get_world_size(), dtype=torch.float).to(device)
-    y_preds = torch.ones(offset * idist.get_world_size(), dtype=torch.float).to(device)
+    def _test(metric_device):
+        n_iters = 100
+        batch_size = 10
 
-    def update(engine, i):
-        return (
-            y_preds[i * s + offset * rank : (i + 1) * s + offset * rank],
-            y_true[i * s + offset * rank : (i + 1) * s + offset * rank],
-        )
+        y_true = torch.arange(0, n_iters * batch_size, dtype=torch.float).to(device)
+        y_preds = torch.ones(n_iters * batch_size, dtype=torch.float).to(device)
+
+        def update(engine, i):
+            return (
+                y_preds[i * batch_size : (i + 1) * batch_size],
+                y_true[i * batch_size : (i + 1) * batch_size],
+            )
 
-    def _test(metric_device):
         engine = Engine(update)
 
         m = MeanSquaredError(device=metric_device)
@@ -80,6 +81,9 @@ def _test(metric_device):
         data = list(range(n_iters))
         engine.run(data=data, max_epochs=1)
 
+        y_preds = idist.all_gather(y_preds)
+        y_true = idist.all_gather(y_true)
+
         assert "mse" in engine.state.metrics
         res = engine.state.metrics["mse"]

tests/ignite/metrics/test_metrics_lambda.py

Lines changed: 40 additions & 22 deletions
@@ -437,29 +437,25 @@ def compute_true_somemetric(y_pred, y):
 def _test_distrib_integration(device):
 
     rank = idist.get_rank()
-    np.random.seed(12)
 
     n_iters = 10
     batch_size = 10
     n_classes = 10
 
     def _test(metric_device):
-        y_true = np.arange(0, n_iters * batch_size * idist.get_world_size(), dtype="int64") % n_classes
-        y_pred = 0.2 * np.random.rand(n_iters * batch_size * idist.get_world_size(), n_classes)
-        for i in range(n_iters * batch_size * idist.get_world_size()):
+        y_true = torch.arange(0, n_iters * batch_size, dtype=torch.int64).to(device) % n_classes
+        y_pred = 0.2 * torch.rand(n_iters * batch_size, n_classes).to(device)
+        for i in range(n_iters * batch_size):
            if np.random.rand() > 0.4:
                y_pred[i, y_true[i]] = 1.0
            else:
                j = np.random.randint(0, n_classes)
                y_pred[i, j] = 0.7
 
-        y_true = y_true.reshape(n_iters * idist.get_world_size(), batch_size)
-        y_pred = y_pred.reshape(n_iters * idist.get_world_size(), batch_size, n_classes)
-
        def update_fn(engine, i):
-            y_true_batch = y_true[i + rank * n_iters, ...]
-            y_pred_batch = y_pred[i + rank * n_iters, ...]
-            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+            y_true_batch = y_true[i * batch_size : (i + 1) * batch_size, ...]
+            y_pred_batch = y_pred[i * batch_size : (i + 1) * batch_size, ...]
+            return y_pred_batch, y_true_batch
 
        evaluator = Engine(update_fn)
 
@@ -478,13 +474,19 @@ def Fbeta(r, p, beta):
        data = list(range(n_iters))
        state = evaluator.run(data, max_epochs=1)
 
+        y_pred = idist.all_gather(y_pred)
+        y_true = idist.all_gather(y_true)
+
        assert "f1" in state.metrics
        assert "ff1" in state.metrics
-        f1_true = f1_score(y_true.ravel(), np.argmax(y_pred.reshape(-1, n_classes), axis=-1), average="macro")
+        f1_true = f1_score(
+            y_true.ravel().cpu(), np.argmax(y_pred.reshape(-1, n_classes).cpu(), axis=-1), average="macro"
+        )
        assert f1_true == approx(state.metrics["f1"])
        assert 1.0 + f1_true == approx(state.metrics["ff1"])
 
-    for _ in range(3):
+    for i in range(3):
+        torch.manual_seed(12 + rank + i)
        _test("cpu")
    if device.type != "xla":
        _test(idist.device())
@@ -493,28 +495,44 @@
 def _test_distrib_metrics_on_diff_devices(device):
     n_classes = 10
     n_iters = 12
-    s = 16
-    offset = n_iters * s
+    batch_size = 16
     rank = idist.get_rank()
+    torch.manual_seed(12 + rank)
 
-    y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
-    y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
+    y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
+    y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
 
     def update(engine, i):
        return (
-            y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
-            y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
+            y_preds[i * batch_size : (i + 1) * batch_size, :],
+            y_true[i * batch_size : (i + 1) * batch_size],
        )
 
+    evaluator = Engine(update)
+
    precision = Precision(average=False, device="cpu")
    recall = Recall(average=False, device=device)
-    custom_metric = precision * recall
 
-    engine = Engine(update)
-    custom_metric.attach(engine, "custom_metric")
+    def Fbeta(r, p, beta):
+        return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item()
+
+    F1 = MetricsLambda(Fbeta, recall, precision, 1)
+    F1.attach(evaluator, "f1")
+
+    another_f1 = (1.0 + precision * recall * 2 / (precision + recall + 1e-20)).mean().item()
+    another_f1.attach(evaluator, "ff1")
 
    data = list(range(n_iters))
-    engine.run(data, max_epochs=2)
+    state = evaluator.run(data, max_epochs=1)
+
+    y_preds = idist.all_gather(y_preds)
+    y_true = idist.all_gather(y_true)
+
+    assert "f1" in state.metrics
+    assert "ff1" in state.metrics
+    f1_true = f1_score(y_true.ravel(), np.argmax(y_preds.reshape(-1, n_classes), axis=-1), average="macro")
+    assert f1_true == approx(state.metrics["f1"])
+    assert 1.0 + f1_true == approx(state.metrics["ff1"])
 
 
 @pytest.mark.distributed
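For readers unfamiliar with the `MetricsLambda` composition that `_test_distrib_metrics_on_diff_devices` now exercises, here is a minimal single-process sketch. The toy `update` function, data shapes, and the extra `1e-20` guard against empty classes are illustrative assumptions, not part of the commit.

import torch

from ignite.engine import Engine
from ignite.metrics import MetricsLambda, Precision, Recall

n_classes = 4
n_iters, batch_size = 8, 16
y_preds = torch.rand(n_iters * batch_size, n_classes)  # toy class scores
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,))

def update(engine, i):
    # One (prediction, target) batch per engine iteration.
    return y_preds[i * batch_size : (i + 1) * batch_size, :], y_true[i * batch_size : (i + 1) * batch_size]

evaluator = Engine(update)

precision = Precision(average=False)
recall = Recall(average=False)

def Fbeta(r, p, beta):
    # Macro F-beta from per-class precision/recall; 1e-20 avoids 0/0 on empty classes.
    return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r + 1e-20)).item()

# MetricsLambda applies Fbeta to the completed metrics at compute time;
# attaching F1 also attaches its precision/recall dependencies.
F1 = MetricsLambda(Fbeta, recall, precision, 1)
F1.attach(evaluator, "f1")

state = evaluator.run(list(range(n_iters)), max_epochs=1)
print(state.metrics["f1"])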
