Skip to content

Commit 1321339

Browse files
author
nmcguire101
committed
Fixed conflicts
2 parents c115bf9 + 0d40173 commit 1321339

File tree

3 files changed

+3
-24
lines changed

3 files changed

+3
-24
lines changed

ignite/metrics/metric.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from abc import ABCMeta, abstractmethod
32
from collections.abc import Mapping
43
from functools import wraps
@@ -210,17 +209,6 @@ def __init__(
210209
):
211210
self._output_transform = output_transform
212211

213-
# Check device if distributed is initialized:
214-
if idist.get_world_size() > 1:
215-
216-
# check if reset and update methods are decorated. Compute may not be decorated
217-
if not (hasattr(self.reset, "_decorated") and hasattr(self.update, "_decorated")):
218-
warnings.warn(
219-
f"{self.__class__.__name__} class does not support distributed setting. "
220-
"Computed result is not collected across all computing devices",
221-
RuntimeWarning,
222-
)
223-
224212
# Some metrics have a large performance regression when run on XLA devices, so for now, we disallow it.
225213
if torch.device(device).type == "xla":
226214
raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")

tests/ignite/metrics/test_metric.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,6 @@ def update(self, output):
2929
assert output == self.true_output
3030

3131

32-
@pytest.mark.distributed
33-
@pytest.mark.skipif("WORLD_SIZE" not in os.environ, reason="Skip if WORLD_SIZE not in env vars")
34-
@pytest.mark.skipif(torch.cuda.is_available(), reason="Skip if GPU")
35-
def test_metric_warning(distributed_context_single_node_gloo):
36-
y = torch.tensor([1.0])
37-
with pytest.warns(RuntimeWarning, match=r"DummyMetric1 class does not support distributed setting"):
38-
DummyMetric1((y, y))
39-
40-
4132
def test_no_transform():
4233
y_pred = torch.Tensor([[2.0], [-2.0]])
4334
y = torch.zeros(2)

tests/ignite/metrics/test_root_mean_squared_error.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_zero_sample():
1818

1919

2020
@pytest.fixture(params=[0, 1, 2, 3])
21-
def generate_tests(request):
21+
def test_data(request):
2222
return [
2323
(torch.empty(10).uniform_(0, 10), torch.empty(10).uniform_(0, 10), 1),
2424
(torch.empty(10, 1).uniform_(-10, 10), torch.empty(10, 1).uniform_(-10, 10), 1),
@@ -29,11 +29,11 @@ def generate_tests(request):
2929

3030

3131
@pytest.mark.parametrize("n_times", range(3))
32-
def test_compute(n_times, generate_tests):
32+
def test_compute(n_times, test_data):
3333

3434
rmse = RootMeanSquaredError()
3535

36-
(y_pred, y, batch_size) = generate_tests
36+
y_pred, y, batch_size = test_data
3737
rmse.reset()
3838
if batch_size > 1:
3939
n_iters = y.shape[0] // batch_size + 1

0 commit comments

Comments
 (0)