
Commit 3a0c9b8

puhuk and vfdev-5 authored
Update median (#2681)
* Update median
* Update _base.py
* Update _base.py
* Update _base.py
* Update _base.py
* Update _base.py
* Update method name to `_torch_median`
* Update test__base.py
* Update test__base.py
* Update test__base.py
* Update test__base.py
* Update `get_rank` position
* Update test__base.py
* Update test_median_absolute_percentage_error.py
* Update test_median_relative_absolute_error.py
* Update test__base.py

Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 5acbaa3 commit 3a0c9b8

8 files changed, +121 -78 lines changed

ignite/contrib/metrics/regression/_base.py

Lines changed: 10 additions & 0 deletions
@@ -30,6 +30,16 @@ def _check_output_types(output: Tuple[torch.Tensor, torch.Tensor]) -> None:
         raise TypeError(f"Input y dtype should be float 16, 32 or 64, but given {y.dtype}")
 
 
+def _torch_median(output: torch.Tensor) -> float:
+    output = output.view(-1)
+    len_ = len(output)
+
+    if len_ % 2 == 0:
+        return float((torch.kthvalue(output, len_ // 2)[0] + torch.kthvalue(output, len_ // 2 + 1)[0]) / 2)
+    else:
+        return float(torch.kthvalue(output, len_ // 2 + 1)[0])
+
+
 class _BaseRegression(Metric):
     # Base class for all regression metrics
     # `update` method check the shapes and call internal overloaded
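Context for the new helper: torch.median returns the lower of the two middle elements when a tensor has an even number of elements, while numpy.median averages them (pytorch/pytorch issue 1837, cited in the test comments this commit removes). A minimal sketch of the discrepancy and of the kthvalue-based midpoint used above, with illustrative values:

    import numpy as np
    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])    # even number of elements
    print(torch.median(x).item())              # 2.0 -> lower of the two middle values
    print(np.median(x.numpy()))                # 2.5 -> mean of the two middle values

    # midpoint median from two order statistics, as in _torch_median
    n = x.numel()
    lo = torch.kthvalue(x.view(-1), n // 2)[0]
    hi = torch.kthvalue(x.view(-1), n // 2 + 1)[0]
    print(float((lo + hi) / 2))                # 2.5 -> matches numpy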

ignite/contrib/metrics/regression/median_absolute_error.py

Lines changed: 3 additions & 1 deletion
@@ -2,12 +2,14 @@
 
 import torch
 
+from ignite.contrib.metrics.regression._base import _torch_median
+
 from ignite.metrics import EpochMetric
 
 
 def median_absolute_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor) -> float:
     e = torch.abs(y.view_as(y_pred) - y_pred)
-    return torch.median(e).item()
+    return _torch_median(e)
 
 
 class MedianAbsoluteError(EpochMetric):
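With _torch_median in place, the compute function should agree with numpy.median even for an even number of samples, which is what previously forced the "prepend" workaround in the tests below. A quick cross-check sketch (random data and seed are illustrative, not from the commit):

    import numpy as np
    import pytest
    import torch

    from ignite.contrib.metrics.regression.median_absolute_error import median_absolute_error_compute_fn

    torch.manual_seed(0)
    y_pred = torch.rand(100)   # even number of samples
    y = torch.rand(100)

    res = median_absolute_error_compute_fn(y_pred, y)
    np_res = np.median(np.abs(y.numpy() - y_pred.numpy()))
    assert res == pytest.approx(np_res)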

ignite/contrib/metrics/regression/median_absolute_percentage_error.py

Lines changed: 3 additions & 1 deletion
@@ -2,12 +2,14 @@
 
 import torch
 
+from ignite.contrib.metrics.regression._base import _torch_median
+
 from ignite.metrics import EpochMetric
 
 
 def median_absolute_percentage_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor) -> float:
     e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred))
-    return 100.0 * torch.median(e).item()
+    return 100.0 * _torch_median(e)
 
 
 class MedianAbsolutePercentageError(EpochMetric):
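Same substitution for the percentage variant: the midpoint median is taken over the relative errors |y - y_pred| / |y| and scaled by 100. A small worked check with made-up values:

    import torch

    from ignite.contrib.metrics.regression.median_absolute_percentage_error import (
        median_absolute_percentage_error_compute_fn,
    )

    y = torch.tensor([1.0, 2.0, 4.0, 8.0])
    y_pred = torch.tensor([1.1, 1.8, 4.4, 6.0])

    # relative errors: [0.1, 0.1, 0.1, 0.25]; even length, so the midpoint
    # median is (0.1 + 0.1) / 2 = 0.1, i.e. 10 percent
    res = median_absolute_percentage_error_compute_fn(y_pred, y)
    assert abs(res - 10.0) < 1e-5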

ignite/contrib/metrics/regression/median_relative_absolute_error.py

Lines changed: 3 additions & 1 deletion
@@ -2,12 +2,14 @@
 
 import torch
 
+from ignite.contrib.metrics.regression._base import _torch_median
+
 from ignite.metrics import EpochMetric
 
 
 def median_relative_absolute_error_compute_fn(y_pred: torch.Tensor, y: torch.Tensor) -> float:
     e = torch.abs(y.view_as(y_pred) - y_pred) / torch.abs(y.view_as(y_pred) - torch.mean(y))
-    return torch.median(e).item()
+    return _torch_median(e)
 
 
 class MedianRelativeAbsoluteError(EpochMetric):
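Here the denominator is |y - mean(y)| rather than |y|, and the same midpoint median is applied to those ratios. A small worked check with made-up values:

    import torch

    from ignite.contrib.metrics.regression.median_relative_absolute_error import (
        median_relative_absolute_error_compute_fn,
    )

    y = torch.tensor([0.0, 2.0, 4.0, 6.0])       # mean(y) = 3.0
    y_pred = torch.tensor([0.5, 2.5, 3.0, 7.0])

    # e = |y - y_pred| / |y - mean(y)| = [0.5/3, 0.5, 1.0, 1/3]
    # sorted: [0.1667, 0.3333, 0.5, 1.0]; midpoint median = (0.3333 + 0.5) / 2 = 5/12
    res = median_relative_absolute_error_compute_fn(y_pred, y)
    assert abs(res - 5.0 / 12.0) < 1e-5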

tests/ignite/contrib/metrics/regression/test__base.py

Lines changed: 61 additions & 1 deletion
@@ -1,7 +1,13 @@
+from typing import Optional
+
+import numpy as np
+
 import pytest
 import torch
 
-from ignite.contrib.metrics.regression._base import _BaseRegression
+import ignite.distributed as idist
+
+from ignite.contrib.metrics.regression._base import _BaseRegression, _torch_median
 
 
 def test_base_regression_shapes():
@@ -37,3 +43,57 @@ def compute(self):
     with pytest.raises(TypeError, match=r"Input y dtype should be float"):
         y = torch.tensor([1, 1])
         m.update((y.float(), y))
+
+
+@pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)])
+def test_torch_median_numpy(size, device: Optional[str] = None):
+    data = torch.rand(size).to(device)
+    assert _torch_median(data) == np.median(data.cpu().numpy())
+
+
+@pytest.mark.parametrize("size", [101, (31, 3)])
+def test_torch_median_quantile(size, device: Optional[str] = None):
+    data = torch.rand(size).to(device)
+    assert _torch_median(data) == torch.quantile(data, 0.5, interpolation="midpoint")
+
+    size = 101
+    data = torch.rand(size=(size,))
+    assert _torch_median(data) == torch.median(data)
+
+
+@pytest.mark.tpu
+@pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)])
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_on_even_size_xla(size):
+    device = "xla"
+    test_torch_median_numpy(size, device=device)
+
+
+@pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
+def test_on_even_size_gpu(size):
+    test_torch_median_numpy(size, device="cuda")
+
+
+@pytest.mark.parametrize("size", [100, 101, (30, 3), (31, 3)])
+def test_create_even_size_cpu(size):
+    test_torch_median_numpy(size, device="cpu")
+
+
+@pytest.mark.tpu
+@pytest.mark.parametrize("size", [101, (31, 3)])
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_on_odd_size_xla(size):
+    device = "xla"
+    test_torch_median_quantile(size, device=device)
+
+
+@pytest.mark.parametrize("size", [101, (31, 3)])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
+def test_on_odd_size_gpu(size):
+    test_torch_median_quantile(size, device="cuda")
+
+
+@pytest.mark.parametrize("size", [101, (31, 3)])
+def test_create_odd_size_cpu(size):
+    test_torch_median_quantile(size, device="cpu")
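The parametrized sizes exercise both branches of _torch_median after the view(-1) flattening: 100 and (30, 3) give an even number of elements (100 and 90), 101 and (31, 3) an odd one (101 and 93). The CPU variant of the numpy check can be reproduced on its own (same assertion as the new test):

    import numpy as np
    import torch

    from ignite.contrib.metrics.regression._base import _torch_median

    for size in [100, 101, (30, 3), (31, 3)]:
        data = torch.rand(size)
        # flattened length is even for 100 and (30, 3), odd for 101 and (31, 3)
        assert _torch_median(data) == np.median(data.numpy())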

tests/ignite/contrib/metrics/regression/test_median_absolute_error.py

Lines changed: 17 additions & 25 deletions
@@ -108,12 +108,9 @@ def update_fn(engine, batch):
 
 
 def _test_distrib_compute(device):
-    rank = idist.get_rank()
-
     def _test(metric_device):
         metric_device = torch.device(metric_device)
         m = MedianAbsoluteError(device=metric_device)
-        torch.manual_seed(10 + rank)
 
         size = 105
 
@@ -132,39 +129,29 @@ def _test(metric_device):
 
         e = np.abs(np_y - np_y_pred)
 
-        # The results between numpy.median() and torch.median() are Inconsistant
-        # when the length of the array/tensor is even. So this is a hack to avoid that.
-        # issue: https://github.com/pytorch/pytorch/issues/1837
-        if np_y_pred.shape[0] % 2 == 0:
-            e_prepend = np.insert(e, 0, e[0], axis=0)
-            np_res_prepend = np.median(e_prepend)
-            assert pytest.approx(res) == np_res_prepend
-        else:
-            np_res = np.median(e)
-            assert pytest.approx(res) == np_res
-
-    for _ in range(3):
+        np_res = np.median(e)
+        assert pytest.approx(res) == np_res
+
+    rank = idist.get_rank()
+    for i in range(3):
+        torch.manual_seed(10 + rank + i)
         _test("cpu")
         if device.type != "xla":
             _test(idist.device())
 
 
 def _test_distrib_integration(device):
-
-    rank = idist.get_rank()
-    torch.manual_seed(12)
-
     def _test(n_epochs, metric_device):
         metric_device = torch.device(metric_device)
         n_iters = 80
-        size = 105
-        y_true = torch.rand(size=(size,)).to(device)
-        y_preds = torch.rand(size=(size,)).to(device)
+        batch_size = 105
+        y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
+        y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
 
         def update(engine, i):
             return (
-                y_preds[i * size : (i + 1) * size],
-                y_true[i * size : (i + 1) * size],
+                y_preds[i * batch_size : (i + 1) * batch_size],
+                y_true[i * batch_size : (i + 1) * batch_size],
             )
 
         engine = Engine(update)
@@ -175,6 +162,9 @@ def update(engine, i):
         data = list(range(n_iters))
         engine.run(data=data, max_epochs=n_epochs)
 
+        y_preds = idist.all_gather(y_preds)
+        y_true = idist.all_gather(y_true)
+
         assert "mae" in engine.state.metrics
 
         res = engine.state.metrics["mae"]
@@ -191,7 +181,9 @@ def update(engine, i):
     if device.type != "xla":
         metric_devices.append(idist.device())
     for metric_device in metric_devices:
-        for _ in range(2):
+        rank = idist.get_rank()
+        for i in range(2):
+            torch.manual_seed(10 + rank + i)
             _test(n_epochs=1, metric_device=metric_device)
             _test(n_epochs=2, metric_device=metric_device)
 
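In distributed runs each rank generates its own y_true / y_preds and feeds them to the metric, while the metric gathers all ranks' data internally before computing the epoch median; the numpy reference therefore has to be computed on the gathered tensors, which is what the added idist.all_gather calls provide. Roughly the pattern (a sketch; the helper name reference_mae is only for illustration):

    import numpy as np

    import ignite.distributed as idist

    def reference_mae(y_preds, y_true):
        # gather this rank's tensors across all ranks, then compute the baseline
        y_preds = idist.all_gather(y_preds)
        y_true = idist.all_gather(y_true)
        return np.median(np.abs(y_true.cpu().numpy() - y_preds.cpu().numpy()))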

tests/ignite/contrib/metrics/regression/test_median_absolute_percentage_error.py

Lines changed: 15 additions & 32 deletions
@@ -108,12 +108,9 @@ def update_fn(engine, batch):
 
 
 def _test_distrib_compute(device):
-    rank = idist.get_rank()
-
     def _test(metric_device):
         metric_device = torch.device(metric_device)
         m = MedianAbsolutePercentageError(device=metric_device)
-        torch.manual_seed(10 + rank)
 
         size = 105
 
@@ -133,34 +130,24 @@ def _test(metric_device):
 
         e = np.abs(np_y - np_y_pred) / np.abs(np_y)
 
-        # The results between numpy.median() and torch.median() are Inconsistant
-        # when the length of the array/tensor is even. So this is a hack to avoid that.
-        # issue: https://github.com/pytorch/pytorch/issues/1837
-        if np_y_pred.shape[0] % 2 == 0:
-            e_prepend = np.insert(e, 0, e[0], axis=0)
-            np_res_prepend = 100.0 * np.median(e_prepend)
-            assert pytest.approx(res) == np_res_prepend
-        else:
-            np_res = 100.0 * np.median(e)
-            assert pytest.approx(res) == np_res
-
-    for _ in range(3):
+        np_res = 100.0 * np.median(e)
+        assert pytest.approx(res) == np_res
+
+    rank = idist.get_rank()
+    for i in range(3):
+        torch.manual_seed(10 + rank + i)
         _test("cpu")
         if device.type != "xla":
             _test(idist.device())
 
 
 def _test_distrib_integration(device):
-
-    rank = idist.get_rank()
-    torch.manual_seed(12)
-
     def _test(n_epochs, metric_device):
         metric_device = torch.device(metric_device)
         n_iters = 80
         size = 105
-        y_true = torch.rand(size=(size,)).to(device)
-        y_preds = torch.rand(size=(size,)).to(device)
+        y_true = torch.rand(size=(n_iters * size,)).to(device)
+        y_preds = torch.rand(size=(n_iters * size,)).to(device)
 
         def update(engine, i):
             return (
@@ -176,6 +163,9 @@ def update(engine, i):
         data = list(range(n_iters))
         engine.run(data=data, max_epochs=n_epochs)
 
+        y_preds = idist.all_gather(y_preds)
+        y_true = idist.all_gather(y_true)
+
         assert "mape" in engine.state.metrics
 
         res = engine.state.metrics["mape"]
@@ -186,22 +176,15 @@ def update(engine, i):
         e = np.abs(np_y_true - np_y_preds) / np.abs(np_y_true)
         np_res = 100.0 * np.median(e)
 
-        e_prepend = np.insert(e, 0, e[0], axis=0)
-        np_res_prepend = 100.0 * np.median(e_prepend)
-
-        # The results between numpy.median() and torch.median() are Inconsistant
-        # when the length of the array/tensor is even. So this is a hack to avoid that.
-        # issue: https://github.com/pytorch/pytorch/issues/1837
-        if np_y_preds.shape[0] % 2 == 0:
-            assert pytest.approx(res) == np_res_prepend
-        else:
-            assert pytest.approx(res) == np_res
+        assert pytest.approx(res) == np_res
 
     metric_devices = ["cpu"]
     if device.type != "xla":
        metric_devices.append(idist.device())
     for metric_device in metric_devices:
-        for _ in range(2):
+        rank = idist.get_rank()
+        for i in range(2):
+            torch.manual_seed(12 + rank + i)
             _test(n_epochs=1, metric_device=metric_device)
             _test(n_epochs=2, metric_device=metric_device)
 
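For a single-process sanity check, the integration test's setup maps onto plain Engine usage, with the metric read back under the same name the test asserts on. A usage sketch (assuming the usual ignite.contrib.metrics.regression import path; the 80 iterations of 105 samples mirror the test, everything else is illustrative):

    import torch

    from ignite.contrib.metrics.regression import MedianAbsolutePercentageError
    from ignite.engine import Engine

    n_iters, batch_size = 80, 105
    y_true = torch.rand(n_iters * batch_size)
    y_preds = torch.rand(n_iters * batch_size)

    def update(engine, i):
        return y_preds[i * batch_size : (i + 1) * batch_size], y_true[i * batch_size : (i + 1) * batch_size]

    engine = Engine(update)
    MedianAbsolutePercentageError().attach(engine, "mape")
    state = engine.run(data=list(range(n_iters)), max_epochs=1)
    print(state.metrics["mape"])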

tests/ignite/contrib/metrics/regression/test_median_relative_absolute_error.py

Lines changed: 9 additions & 17 deletions
@@ -108,8 +108,6 @@ def update_fn(engine, batch):
 
 
 def _test_distrib_compute(device):
-    rank = idist.get_rank()
-
     def _test(metric_device):
         metric_device = torch.device(metric_device)
         m = MedianRelativeAbsoluteError(device=metric_device)
@@ -133,28 +131,17 @@ def _test(metric_device):
 
         e = np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean())
 
-        # The results between numpy.median() and torch.median() are Inconsistant
-        # when the length of the array/tensor is even. So this is a hack to avoid that.
-        # issue: https://github.com/pytorch/pytorch/issues/1837
-        if np_y_pred.shape[0] % 2 == 0:
-            e_prepend = np.insert(e, 0, e[0], axis=0)
-            np_res_prepend = np.median(e_prepend)
-            assert pytest.approx(res) == np_res_prepend
-        else:
-            np_res = np.median(e)
-            assert pytest.approx(res) == np_res
+        np_res = np.median(e)
+        assert pytest.approx(res) == np_res
 
+    rank = idist.get_rank()
     for _ in range(3):
         _test("cpu")
         if device.type != "xla":
             _test(idist.device())
 
 
 def _test_distrib_integration(device):
-
-    rank = idist.get_rank()
-    torch.manual_seed(12)
-
     def _test(n_epochs, metric_device):
         metric_device = torch.device(metric_device)
         n_iters = 80
@@ -176,6 +163,9 @@ def update(engine, i):
         data = list(range(n_iters))
         engine.run(data=data, max_epochs=n_epochs)
 
+        y_true = idist.all_gather(y_true)
+        y_preds = idist.all_gather(y_preds)
+
         assert "mare" in engine.state.metrics
 
         res = engine.state.metrics["mare"]
@@ -192,7 +182,9 @@ def update(engine, i):
     if device.type != "xla":
         metric_devices.append(idist.device())
     for metric_device in metric_devices:
-        for _ in range(2):
+        rank = idist.get_rank()
+        for i in range(2):
+            torch.manual_seed(12 + rank + i)
             _test(n_epochs=1, metric_device=metric_device)
             _test(n_epochs=2, metric_device=metric_device)
 
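The distributed integration tests in all three files converge on the same seeding pattern: seed inside the repetition loop with an offset per rank and per repeat, so neither repetitions nor ranks reuse identical random data. The pattern, as a sketch only:

    import torch

    import ignite.distributed as idist

    rank = idist.get_rank()
    for i in range(2):
        # a different seed on every rank and every repetition
        torch.manual_seed(12 + rank + i)
        # ... build random y_true / y_preds and run one repetition of the check ...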
