Commit 1afabbf

Merge branch 'master' into rewrite-engine-terminate-and-co
2 parents 53c9659 + c506c31

2 files changed (+21 / -21 lines)

ignite/metrics/gan/fid.py

Lines changed: 8 additions & 10 deletions
@@ -12,6 +12,12 @@
 ]


+if Version(torch.__version__) <= Version("1.7.0"):
+    torch_outer = torch.ger
+else:
+    torch_outer = torch.outer
+
+
 def fid_score(
     mu1: torch.Tensor, mu2: torch.Tensor, sigma1: torch.Tensor, sigma2: torch.Tensor, eps: float = 1e-6
 ) -> float:
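
This hunk hoists the PyTorch version check to module level: torch_outer is bound once, to torch.ger on PyTorch 1.7.0 and older (torch.outer was only added in 1.8), and to torch.outer otherwise. For 1-D tensors the two calls compute the same outer product, so the alias is a drop-in replacement; a minimal sketch of the equivalence:

# torch.ger is the older (now deprecated) name for the 1-D outer product;
# torch.outer is its replacement from PyTorch 1.8 onward.
import torch
from packaging.version import Version

torch_outer = torch.ger if Version(torch.__version__) <= Version("1.7.0") else torch.outer

a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 4.0, 5.0])
print(torch_outer(a, b))
# tensor([[ 3.,  4.,  5.],
#         [ 6.,  8., 10.]])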
@@ -193,22 +199,14 @@ def __init__(
     def _online_update(features: torch.Tensor, total: torch.Tensor, sigma: torch.Tensor) -> None:

         total += features
-
-        if Version(torch.__version__) <= Version("1.7.0"):
-            sigma += torch.ger(features, features)
-        else:
-            sigma += torch.outer(features, features)
+        sigma += torch_outer(features, features)

     def _get_covariance(self, sigma: torch.Tensor, total: torch.Tensor) -> torch.Tensor:
         r"""
         Calculates covariance from mean and sum of products of variables
         """

-        if Version(torch.__version__) <= Version("1.7.0"):
-            sub_matrix = torch.ger(total, total)
-        else:
-            sub_matrix = torch.outer(total, total)
-
+        sub_matrix = torch_outer(total, total)
         sub_matrix = sub_matrix / self._num_examples

         return (sigma - sub_matrix) / (self._num_examples - 1)
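
With the alias in place, _online_update and _get_covariance no longer branch on the version at every call. The underlying math is a running covariance: total accumulates the sum of feature vectors, sigma the sum of their outer products, and the sample covariance is recovered as (sigma - total*totalT/n) / (n - 1). A self-contained sketch of that pattern, assuming PyTorch >= 1.8 so torch.outer is available directly (the class name RunningCovariance is illustrative, not part of ignite):

import torch


class RunningCovariance:
    def __init__(self, dim: int) -> None:
        self.total = torch.zeros(dim)        # running sum of feature vectors
        self.sigma = torch.zeros(dim, dim)   # running sum of outer products
        self.num_examples = 0

    def update(self, features: torch.Tensor) -> None:
        self.total += features
        self.sigma += torch.outer(features, features)
        self.num_examples += 1

    def covariance(self) -> torch.Tensor:
        # sample covariance: (sum of x xT - (sum x)(sum x)T / n) / (n - 1)
        sub_matrix = torch.outer(self.total, self.total) / self.num_examples
        return (self.sigma - sub_matrix) / (self.num_examples - 1)


# Agrees with torch.cov (available in PyTorch >= 1.10) on random data:
x = torch.randn(100, 4)
rc = RunningCovariance(4)
for row in x:
    rc.update(row)
assert torch.allclose(rc.covariance(), torch.cov(x.T), atol=1e-4)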

tests/ignite/contrib/metrics/test_average_precision.py

Lines changed: 13 additions & 11 deletions
@@ -215,11 +215,9 @@ def get_test_cases():
 def _test_distrib_integration_binary_input(device):

     rank = idist.get_rank()
-    torch.manual_seed(12)
     n_iters = 80
-    s = 16
+    batch_size = 16
     n_classes = 2
-    offset = n_iters * s

     def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
         metric_device = torch.device(metric_device)
@@ -232,6 +230,9 @@ def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
         data = list(range(n_iters))
         engine.run(data=data, max_epochs=n_epochs)

+        y_true = idist.all_gather(y_true)
+        y_preds = idist.all_gather(y_preds)
+
         assert "ap" in engine.state.metrics

         res = engine.state.metrics["ap"]
@@ -240,24 +241,25 @@
         assert pytest.approx(res) == true_res

     def get_tests(is_N):
+        torch.manual_seed(12 + rank)
         if is_N:
-            y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
-            y_preds = torch.rand(offset * idist.get_world_size()).to(device)
+            y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
+            y_preds = torch.rand(n_iters * batch_size).to(device)

             def update_fn(engine, i):
                 return (
-                    y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
-                    y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
+                    y_preds[i * batch_size : (i + 1) * batch_size],
+                    y_true[i * batch_size : (i + 1) * batch_size],
                 )

         else:
-            y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(), 10)).to(device)
-            y_preds = torch.randint(0, n_classes, size=(offset * idist.get_world_size(), 10)).to(device)
+            y_true = torch.randint(0, n_classes, size=(n_iters * batch_size, 10)).to(device)
+            y_preds = torch.randint(0, n_classes, size=(n_iters * batch_size, 10)).to(device)

             def update_fn(engine, i):
                 return (
-                    y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
-                    y_true[i * s + rank * offset : (i + 1) * s + rank * offset, :],
+                    y_preds[i * batch_size : (i + 1) * batch_size, :],
+                    y_true[i * batch_size : (i + 1) * batch_size, :],
                )

         return y_preds, y_true, update_fn
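
The test-side change follows the same spirit: instead of one global tensor sliced by rank offsets, each process now seeds its own generator with 12 + rank, builds a local y_true/y_preds of n_iters * batch_size samples, and only indexes by the local batch. The reference value is then computed after combining the per-rank tensors with idist.all_gather, which mirrors what the distributed AveragePrecision metric itself accumulates. A hedged sketch of that gather step (compute_reference_ap is an illustrative helper, not part of the test suite, and assumes an initialized idist backend):

import torch
import ignite.distributed as idist
from sklearn.metrics import average_precision_score


def compute_reference_ap(y_preds: torch.Tensor, y_true: torch.Tensor) -> float:
    # Gather per-rank tensors so every process scores the full dataset,
    # matching what the distributed metric sees after its own all-reduce.
    y_preds = idist.all_gather(y_preds)
    y_true = idist.all_gather(y_true)
    return average_precision_score(y_true.cpu().numpy(), y_preds.cpu().numpy())


rank = idist.get_rank()
torch.manual_seed(12 + rank)               # different data on each rank
n_iters, batch_size, n_classes = 80, 16, 2

y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,))
y_preds = torch.rand(n_iters * batch_size)

true_res = compute_reference_ap(y_preds, y_true)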

0 commit comments
