Skip to content

Commit 8ef3341

Browse files
authored
[MRG] Speedup tests (#262)
* speedup tests * add color to tests and timings * add test unbalanced * stupid missing -
1 parent 2dbeeda commit 8ef3341

File tree

8 files changed

+77
-44
lines changed

8 files changed

+77
-44
lines changed

.github/workflows/build_tests.yml

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ jobs:
4040
pip install -e .
4141
- name: Run tests
4242
run: |
43-
python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
43+
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
4444
- name: Upload codecov
4545
run: |
4646
codecov
@@ -95,7 +95,7 @@ jobs:
9595
pip install -e .
9696
- name: Run tests
9797
run: |
98-
python -m pytest -v test/ ot/ --ignore ot/gpu/
98+
python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes
9999
100100
101101
macos:
@@ -122,7 +122,7 @@ jobs:
122122
pip install -e .
123123
- name: Run tests
124124
run: |
125-
python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
125+
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
126126
127127
128128
windows:
@@ -150,4 +150,4 @@ jobs:
150150
python -m pip install -e .
151151
- name: Run tests
152152
run: |
153-
python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot
153+
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes

Makefile

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,10 +45,10 @@ pep8 :
4545
flake8 examples/ ot/ test/
4646

4747
test : FORCE pep8
48-
$(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
48+
$(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/
4949

5050
pytest : FORCE
51-
$(PYTHON) -m pytest -v test/ --doctest-modules --ignore ot/gpu/
51+
$(PYTHON) -m pytest --durations=20 -v test/ --doctest-modules --ignore ot/gpu/
5252

5353
release :
5454
twine upload dist/*

test/test_bregman.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -293,7 +293,7 @@ def test_unmix():
293293

294294
def test_empirical_sinkhorn():
295295
# test sinkhorn
296-
n = 100
296+
n = 10
297297
a = ot.unif(n)
298298
b = ot.unif(n)
299299

@@ -332,7 +332,7 @@ def test_empirical_sinkhorn():
332332

333333
def test_lazy_empirical_sinkhorn():
334334
# test sinkhorn
335-
n = 100
335+
n = 10
336336
a = ot.unif(n)
337337
b = ot.unif(n)
338338
numIterMax = 1000
@@ -342,7 +342,7 @@ def test_lazy_empirical_sinkhorn():
342342
M = ot.dist(X_s, X_t)
343343
M_m = ot.dist(X_s, X_t, metric='minkowski')
344344

345-
f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
345+
f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
346346
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
347347
sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
348348

@@ -458,6 +458,7 @@ def test_implemented_methods():
458458
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
459459

460460

461+
@pytest.mark.filterwarnings("ignore:Bottleneck")
461462
def test_screenkhorn():
462463
# test screenkhorn
463464
rng = np.random.RandomState(0)

test/test_da.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,8 @@ def test_sinkhorn_l1l2_transport_class():
106106
"""test_sinkhorn_transport
107107
"""
108108

109-
ns = 150
110-
nt = 200
109+
ns = 50
110+
nt = 100
111111

112112
Xs, ys = make_data_classif('3gauss', ns)
113113
Xt, yt = make_data_classif('3gauss2', nt)
@@ -448,8 +448,8 @@ def test_mapping_transport_class():
448448
"""test_mapping_transport
449449
"""
450450

451-
ns = 60
452-
nt = 120
451+
ns = 20
452+
nt = 30
453453

454454
Xs, ys = make_data_classif('3gauss', ns)
455455
Xt, yt = make_data_classif('3gauss2', nt)

test/test_gromov.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,8 @@
99
import numpy as np
1010
import ot
1111

12+
import pytest
13+
1214

1315
def test_gromov():
1416
n_samples = 50 # nb samples
@@ -128,29 +130,30 @@ def test_gromov_barycenter():
128130
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
129131

130132

133+
@pytest.mark.filterwarnings("ignore:divide")
131134
def test_gromov_entropic_barycenter():
132-
ns = 50
133-
nt = 60
135+
ns = 20
136+
nt = 30
134137

135138
Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42)
136139
Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42)
137140

138141
C1 = ot.dist(Xs)
139142
C2 = ot.dist(Xt)
140143

141-
n_samples = 3
144+
n_samples = 2
142145
Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
143146
[ot.unif(ns), ot.unif(nt)
144147
], ot.unif(n_samples), [.5, .5],
145-
'square_loss', 2e-3,
146-
max_iter=100, tol=1e-3,
148+
'square_loss', 1e-3,
149+
max_iter=50, tol=1e-5,
147150
verbose=True)
148151
np.testing.assert_allclose(Cb.shape, (n_samples, n_samples))
149152

150153
Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2],
151154
[ot.unif(ns), ot.unif(nt)
152155
], ot.unif(n_samples), [.5, .5],
153-
'kl_loss', 2e-3,
156+
'kl_loss', 1e-3,
154157
max_iter=100, tol=1e-3)
155158
np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples))
156159

test/test_optim.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@ def df(G):
3737
np.testing.assert_allclose(b, G.sum(0))
3838

3939

40-
def test_conditional_gradient2():
41-
n = 1000 # nb samples
40+
def test_conditional_gradient_itermax():
41+
n = 100 # nb samples
4242

4343
mu_s = np.array([0, 0])
4444
cov_s = np.array([[1, 0], [0, 1]])
@@ -63,7 +63,7 @@ def df(G):
6363

6464
reg = 1e-1
6565

66-
G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
66+
G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=10000,
6767
verbose=True, log=True)
6868

6969
np.testing.assert_allclose(a, G.sum(1))

test/test_stochastic.py

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030

3131
def test_stochastic_sag():
3232
# test sag
33-
n = 15
33+
n = 10
3434
reg = 1
3535
numItermax = 30000
3636
rng = np.random.RandomState(0)
@@ -45,9 +45,9 @@ def test_stochastic_sag():
4545

4646
# check constratints
4747
np.testing.assert_allclose(
48-
u, G.sum(1), atol=1e-04) # cf convergence sag
48+
u, G.sum(1), atol=1e-03) # cf convergence sag
4949
np.testing.assert_allclose(
50-
u, G.sum(0), atol=1e-04) # cf convergence sag
50+
u, G.sum(0), atol=1e-03) # cf convergence sag
5151

5252

5353
#############################################################################
@@ -60,9 +60,9 @@ def test_stochastic_sag():
6060

6161
def test_stochastic_asgd():
6262
# test asgd
63-
n = 15
63+
n = 10
6464
reg = 1
65-
numItermax = 100000
65+
numItermax = 10000
6666
rng = np.random.RandomState(0)
6767

6868
x = rng.randn(n, 2)
@@ -75,9 +75,9 @@ def test_stochastic_asgd():
7575

7676
# check constratints
7777
np.testing.assert_allclose(
78-
u, G.sum(1), atol=1e-03) # cf convergence asgd
78+
u, G.sum(1), atol=1e-02) # cf convergence asgd
7979
np.testing.assert_allclose(
80-
u, G.sum(0), atol=1e-03) # cf convergence asgd
80+
u, G.sum(0), atol=1e-02) # cf convergence asgd
8181

8282

8383
#############################################################################
@@ -90,9 +90,9 @@ def test_stochastic_asgd():
9090

9191
def test_sag_asgd_sinkhorn():
9292
# test all algorithms
93-
n = 15
93+
n = 10
9494
reg = 1
95-
nb_iter = 100000
95+
nb_iter = 10000
9696
rng = np.random.RandomState(0)
9797

9898
x = rng.randn(n, 2)
@@ -107,17 +107,17 @@ def test_sag_asgd_sinkhorn():
107107

108108
# check constratints
109109
np.testing.assert_allclose(
110-
G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
110+
G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-02)
111111
np.testing.assert_allclose(
112-
G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
112+
G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-02)
113113
np.testing.assert_allclose(
114-
G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
114+
G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
115115
np.testing.assert_allclose(
116-
G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
116+
G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
117117
np.testing.assert_allclose(
118-
G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
118+
G_sag, G_sinkhorn, atol=1e-02) # cf convergence sag
119119
np.testing.assert_allclose(
120-
G_asgd, G_sinkhorn, atol=1e-03) # cf convergence asgd
120+
G_asgd, G_sinkhorn, atol=1e-02) # cf convergence asgd
121121

122122

123123
#############################################################################
@@ -136,7 +136,7 @@ def test_stochastic_dual_sgd():
136136
# test sgd
137137
n = 10
138138
reg = 1
139-
numItermax = 15000
139+
numItermax = 5000
140140
batch_size = 10
141141
rng = np.random.RandomState(0)
142142

@@ -167,7 +167,7 @@ def test_dual_sgd_sinkhorn():
167167
# test all dual algorithms
168168
n = 10
169169
reg = 1
170-
nb_iter = 15000
170+
nb_iter = 5000
171171
batch_size = 10
172172
rng = np.random.RandomState(0)
173173

@@ -183,11 +183,11 @@ def test_dual_sgd_sinkhorn():
183183

184184
# check constratints
185185
np.testing.assert_allclose(
186-
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
186+
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-02)
187187
np.testing.assert_allclose(
188-
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
188+
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-02)
189189
np.testing.assert_allclose(
190-
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
190+
G_sgd, G_sinkhorn, atol=1e-02) # cf convergence sgd
191191

192192
# Test gaussian
193193
n = 30

test/test_unbalanced.py

Lines changed: 31 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,8 @@ def test_stabilized_vs_sinkhorn():
115115
G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
116116
method="sinkhorn_stabilized",
117117
reg_m=reg_m,
118-
log=True)
118+
log=True,
119+
verbose=True)
119120
G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
120121
method="sinkhorn", log=True)
121122

@@ -138,7 +139,7 @@ def test_unbalanced_barycenter(method):
138139
reg_m = 1.
139140

140141
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
141-
method=method, log=True)
142+
method=method, log=True, verbose=True)
142143
# check fixed point equations
143144
fi = reg_m / (reg_m + epsilon)
144145
logA = np.log(A + 1e-16)
@@ -173,6 +174,7 @@ def test_barycenter_stabilized_vs_sinkhorn():
173174
reg_m=reg_m, log=True,
174175
tau=100,
175176
method="sinkhorn_stabilized",
177+
verbose=True
176178
)
177179
q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
178180
method="sinkhorn",
@@ -182,6 +184,33 @@ def test_barycenter_stabilized_vs_sinkhorn():
182184
q, qstable, atol=1e-05)
183185

184186

187+
def test_wrong_method():
188+
189+
n = 10
190+
rng = np.random.RandomState(42)
191+
192+
x = rng.randn(n, 2)
193+
a = ot.utils.unif(n)
194+
195+
# make dists unbalanced
196+
b = ot.utils.unif(n) * 1.5
197+
198+
M = ot.dist(x, x)
199+
epsilon = 1.
200+
reg_m = 1.
201+
202+
with pytest.raises(ValueError):
203+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
204+
reg_m=reg_m,
205+
method='badmethod',
206+
log=True,
207+
verbose=True)
208+
with pytest.raises(ValueError):
209+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
210+
method='badmethod',
211+
verbose=True)
212+
213+
185214
def test_implemented_methods():
186215
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
187216
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']

0 commit comments

Comments (0)