Commit 0e431c2

[MRG] Add tests about type and GPU for emd/emd2 + 1d variants + wasserstein1d (#304)
Squashed commit message:

* new test gpu
* pep 8 of course
* debug torch
* jax with gpu
* device put
* device put
* it works
* emd1d and emd2_1d working
* emd_1d and emd2_1d done
* cleanup
* of course
* should work on gpu now
* tests done + pep8
1 parent: 2fe69eb

File tree: 4 files changed, +146 −48 lines

ot/backend.py

Lines changed: 19 additions & 1 deletion
@@ -102,6 +102,7 @@ class Backend():

     __name__ = None
     __type__ = None
+    __type_list__ = None

     rng_ = None

@@ -663,6 +664,8 @@ class NumpyBackend(Backend):

     __name__ = 'numpy'
     __type__ = np.ndarray
+    __type_list__ = [np.array(1, dtype=np.float32),
+                     np.array(1, dtype=np.float64)]

     rng_ = np.random.RandomState()

@@ -888,20 +891,25 @@ class JaxBackend(Backend):

     __name__ = 'jax'
     __type__ = jax_type
+    __type_list__ = None

     rng_ = None

     def __init__(self):
         self.rng_ = jax.random.PRNGKey(42)

+        for d in jax.devices():
+            self.__type_list__ = [jax.device_put(jnp.array(1, dtype=np.float32), d),
+                                  jax.device_put(jnp.array(1, dtype=np.float64), d)]
+
     def to_numpy(self, a):
         return np.array(a)

     def from_numpy(self, a, type_as=None):
         if type_as is None:
             return jnp.array(a)
         else:
-            return jnp.array(a).astype(type_as.dtype)
+            return jax.device_put(jnp.array(a).astype(type_as.dtype), type_as.device_buffer.device())

     def set_gradients(self, val, inputs, grads):
         from jax.flatten_util import ravel_pytree

@@ -1130,6 +1138,7 @@ class TorchBackend(Backend):

     __name__ = 'torch'
     __type__ = torch_type
+    __type_list__ = None

     rng_ = None

@@ -1138,6 +1147,13 @@ def __init__(self):
         self.rng_ = torch.Generator()
         self.rng_.seed()

+        self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
+                              torch.tensor(1, dtype=torch.float64)]
+
+        if torch.cuda.is_available():
+            self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
+            self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+
         from torch.autograd import Function

         # define a function that takes inputs val and grads

@@ -1160,6 +1176,8 @@ def to_numpy(self, a):
         return a.cpu().detach().numpy()

     def from_numpy(self, a, type_as=None):
+        if isinstance(a, float):
+            a = np.array(a)
         if type_as is None:
             return torch.from_numpy(a)
         else:
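
In short, each backend now advertises a __type_list__ of one-element arrays, one per supported dtype (and, for torch with a GPU present, per device), and from_numpy(..., type_as=tp) reproduces both the dtype and the device of tp. A minimal sketch of the intended usage, assuming only what the diff above shows:

    import numpy as np
    from ot.backend import NumpyBackend  # class patched in the diff above

    nx = NumpyBackend()
    for tp in nx.__type_list__:  # one sample array per supported dtype/device
        x = nx.from_numpy(np.linspace(0, 5, 10), type_as=tp)
        # from_numpy(..., type_as=tp) must reproduce tp's dtype (and device)
        assert x.dtype == tp.dtype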

ot/lp/solver_1d.py

Lines changed: 7 additions & 7 deletions
@@ -235,8 +235,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,

     # ensure that same mass
     np.testing.assert_almost_equal(
-        nx.sum(a, axis=0),
-        nx.sum(b, axis=0),
+        nx.to_numpy(nx.sum(a, axis=0)),
+        nx.to_numpy(nx.sum(b, axis=0)),
         err_msg='a and b vector must have the same sum'
     )
     b = b * nx.sum(a) / nx.sum(b)

@@ -247,10 +247,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     perm_b = nx.argsort(x_b_1d)

     G_sorted, indices, cost = emd_1d_sorted(
-        nx.to_numpy(a[perm_a]),
-        nx.to_numpy(b[perm_b]),
-        nx.to_numpy(x_a_1d[perm_a]),
-        nx.to_numpy(x_b_1d[perm_b]),
+        nx.to_numpy(a[perm_a]).astype(np.float64),
+        nx.to_numpy(b[perm_b]).astype(np.float64),
+        nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
+        nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
         metric=metric, p=p
     )

@@ -266,7 +266,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
     elif str(nx) == "jax":
         warnings.warn("JAX does not support sparse matrices, converting to dense")
     if log:
-        log = {'cost': cost}
+        log = {'cost': nx.from_numpy(cost, type_as=x_a)}
         return G, log
     return G
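
The pattern here: the compiled emd_1d_sorted kernel presumably accepts only double precision, so inputs are cast to float64 NumPy arrays on the way in, and the returned cost is converted back with from_numpy(..., type_as=x_a) so the caller gets its original dtype and device. A minimal sketch of that round trip (the core_solver argument is a hypothetical stand-in for the compiled kernel):

    import numpy as np

    def solve_via_float64_core(nx, a, x_a, core_solver):
        # Cast backend arrays (possibly float32, possibly on GPU) down to
        # float64 NumPy arrays for the compiled core...
        cost = core_solver(nx.to_numpy(a).astype(np.float64),
                           nx.to_numpy(x_a).astype(np.float64))
        # ...then push the result back to the caller's dtype and device.
        return nx.from_numpy(cost, type_as=x_a)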

test/test_1d_solver.py

Lines changed: 93 additions & 0 deletions
@@ -83,3 +83,96 @@ def test_wasserstein_1d(nx):
     Xb = nx.from_numpy(X)
     res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2)
     np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4)
+
+
+@pytest.mark.parametrize('nx', backend_list)
+def test_wasserstein_1d_type_devices(nx):
+
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+
+        print(tp.dtype)
+
+        xb = nx.from_numpy(x, type_as=tp)
+        rho_ub = nx.from_numpy(rho_u, type_as=tp)
+        rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+        res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
+
+        if not str(nx) == 'numpy':
+            assert res.dtype == xb.dtype
+
+
+def test_emd_1d_emd2_1d():
+    # test emd1d gives similar results as emd
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.randn(n, 1)
+    v = rng.randn(m, 1)
+
+    M = ot.dist(u, v, metric='sqeuclidean')
+
+    G, log = ot.emd([], [], M, log=True)
+    wass = log["cost"]
+    G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
+    wass1d = log["cost"]
+    wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
+    wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass, wass1d)
+    np.testing.assert_allclose(wass, wass1d_emd2)
+
+    # check loss is similar to scipy's implementation for Euclidean metric
+    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
+    np.testing.assert_allclose(wass_sp, wass1d_euc)
+
+    # check constraints
+    np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
+    np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
+
+    # check G is similar
+    np.testing.assert_allclose(G, G_1d, atol=1e-15)
+
+    # check AssertionError is raised if called on non 1d arrays
+    u = np.random.randn(n, 2)
+    v = np.random.randn(m, 2)
+    with pytest.raises(AssertionError):
+        ot.emd_1d(u, v, [], [])
+
+
+def test_emd1d_type_devices(nx):
+
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+
+        print(tp.dtype)
+
+        xb = nx.from_numpy(x, type_as=tp)
+        rho_ub = nx.from_numpy(rho_u, type_as=tp)
+        rho_vb = nx.from_numpy(rho_v, type_as=tp)
+
+        emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
+
+        emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
+
+        assert emd.dtype == xb.dtype
+        if not str(nx) == 'numpy':
+            assert emd2.dtype == xb.dtype
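
These tests rely on a backend_list and an nx fixture coming from the test setup; a plausible sketch of that wiring (hypothetical, the real definitions live in the repository's test configuration, e.g. a conftest.py):

    # conftest.py -- hypothetical sketch of the fixture used above
    import pytest
    from ot.backend import get_backend_list  # assumed to return backend instances

    backend_list = get_backend_list()

    @pytest.fixture(params=backend_list, ids=lambda nx: str(nx))
    def nx(request):
        # each test taking an `nx` argument runs once per available backend
        return request.param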

test/test_ot.py

Lines changed: 27 additions & 40 deletions
@@ -12,7 +12,6 @@
 import ot
 from ot.datasets import make_1D_gauss as gauss
 from ot.backend import torch
-from scipy.stats import wasserstein_distance


 def test_emd_dimension_and_mass_mismatch():

@@ -77,6 +76,33 @@ def test_emd2_backends(nx):
     np.allclose(val, nx.to_numpy(valb))


+def test_emd_emd2_types_devices(nx):
+    n_samples = 100
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples, n_features)
+    y = rng.randn(n_samples, n_features)
+    a = ot.utils.unif(n_samples)
+
+    M = ot.dist(x, y)
+
+    for tp in nx.__type_list__:
+
+        print(tp.dtype)
+
+        ab = nx.from_numpy(a, type_as=tp)
+        Mb = nx.from_numpy(M, type_as=tp)
+
+        Gb = ot.emd(ab, ab, Mb)
+
+        w = ot.emd2(ab, ab, Mb)
+
+        assert Gb.dtype == Mb.dtype
+        if not str(nx) == 'numpy':
+            assert w.dtype == Mb.dtype
+
+
 def test_emd2_gradients():
     n_samples = 100
     n_features = 2

@@ -126,45 +152,6 @@ def test_emd_emd2():
     np.testing.assert_allclose(w, 0)


-def test_emd_1d_emd2_1d():
-    # test emd1d gives similar results as emd
-    n = 20
-    m = 30
-    rng = np.random.RandomState(0)
-    u = rng.randn(n, 1)
-    v = rng.randn(m, 1)
-
-    M = ot.dist(u, v, metric='sqeuclidean')
-
-    G, log = ot.emd([], [], M, log=True)
-    wass = log["cost"]
-    G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
-    wass1d = log["cost"]
-    wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
-    wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
-
-    # check loss is similar
-    np.testing.assert_allclose(wass, wass1d)
-    np.testing.assert_allclose(wass, wass1d_emd2)
-
-    # check loss is similar to scipy's implementation for Euclidean metric
-    wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
-    np.testing.assert_allclose(wass_sp, wass1d_euc)
-
-    # check constraints
-    np.testing.assert_allclose(np.ones((n,)) / n, G.sum(1))
-    np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0))
-
-    # check G is similar
-    np.testing.assert_allclose(G, G_1d, atol=1e-15)
-
-    # check AssertionError is raised if called on non 1d arrays
-    u = np.random.randn(n, 2)
-    v = np.random.randn(m, 2)
-    with pytest.raises(AssertionError):
-        ot.emd_1d(u, v, [], [])
-
-
 def test_emd_empty():
     # test emd and emd2 for simple identity
     n = 100
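
Net effect for users: ot.emd and ot.emd2 now return values in the dtype (and on the device) of their inputs. A minimal user-level sketch of that contract, assuming a torch install (with CUDA available the same holds for tensors created on 'cuda'):

    import numpy as np
    import torch
    import ot

    rng = np.random.RandomState(0)
    M = ot.dist(rng.randn(5, 2), rng.randn(5, 2))  # cost matrix (float64 numpy)
    a = ot.utils.unif(5)                           # uniform weights

    ab = torch.tensor(a, dtype=torch.float32)
    Mb = torch.tensor(M, dtype=torch.float32)

    G = ot.emd(ab, ab, Mb)   # transport plan comes back as a float32 tensor
    w = ot.emd2(ab, ab, Mb)  # the scalar loss keeps the input dtype as well
    assert G.dtype == Mb.dtype and w.dtype == Mb.dtype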
