Skip to content

Commit aecd04a

Browse files
make GDL tests faster
1 parent 278a1aa commit aecd04a

File tree

1 file changed

+27
-37
lines changed

1 file changed

+27
-37
lines changed

test/test_gromov.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
798798

799799
# create dataset composed from 2 structures which are repeated 5 times
800800
shape = 10
801-
n_samples = 10
801+
n_samples = 2
802802
n_atoms = 2
803803
projection = 'nonnegative_symmetric'
804804
X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
@@ -827,6 +827,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
827827
use_adam_optimizer = True
828828
verbose = False
829829
tol = 10**(-5)
830+
epochs = 1
830831

831832
initial_total_reconstruction = 0
832833
for i in range(n_samples):
@@ -839,7 +840,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
839840
# > Learn the dictionary using this init
840841
Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
841842
Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init,
842-
epochs=5, batch_size=2 * n_samples, learning_rate=1., reg=0.,
843+
epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0.,
843844
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
844845
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
845846
)
@@ -855,17 +856,10 @@ def test_gromov_wasserstein_dictionary_learning(nx):
855856
np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
856857

857858
# Test: Perform same experiments after going through backend
858-
initial_total_reconstruction_b = 0
859-
for i in range(n_samples):
860-
_, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
861-
Csb[i], Cdict_initb, p=psb[i], q=qb, reg=0.,
862-
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
863-
)
864-
initial_total_reconstruction_b += reconstruction
865859

866860
Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning(
867861
Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb,
868-
epochs=5, batch_size=n_samples, learning_rate=1., reg=0.,
862+
epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
869863
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
870864
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
871865
)
@@ -878,7 +872,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
878872
)
879873
total_reconstruction_b += reconstruction
880874

881-
np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction_b)
875+
np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
882876
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
883877
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
884878
np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
@@ -888,7 +882,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
888882
np.random.seed(0)
889883
Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
890884
Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None,
891-
epochs=5, batch_size=n_samples, learning_rate=1., reg=0.,
885+
epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
892886
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
893887
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
894888
)
@@ -907,7 +901,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
907901
np.random.seed(0)
908902
Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
909903
Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None,
910-
epochs=5, batch_size=n_samples, learning_rate=1., reg=0.,
904+
epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
911905
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
912906
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
913907
)
@@ -925,14 +919,15 @@ def test_gromov_wasserstein_dictionary_learning(nx):
925919

926920
# Test: Perform same comparison without providing the initial dictionary being an optional input
927921
# and testing other optimization settings untested until now.
922+
# We pass previously estimated dictionaries to speed up the process.
928923
use_adam_optimizer = False
929924
verbose = True
930925
use_log = True
931926

932927
np.random.seed(0)
933928
Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
934-
Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=None,
935-
epochs=5, batch_size=n_samples, learning_rate=10., reg=0.,
929+
Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict,
930+
epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
936931
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
937932
projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
938933
)
@@ -945,13 +940,13 @@ def test_gromov_wasserstein_dictionary_learning(nx):
945940
)
946941
total_reconstruction_bis2 += reconstruction
947942

948-
np.testing.assert_array_less(total_reconstruction_bis2, initial_total_reconstruction)
943+
np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
949944

950945
# Test: Same after going through backend
951946
np.random.seed(0)
952947
Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
953-
Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None,
954-
epochs=5, batch_size=n_samples, learning_rate=10., reg=0.,
948+
Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb,
949+
epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
955950
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
956951
projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
957952
)
@@ -1063,7 +1058,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
10631058

10641059
# create dataset composed from 2 structures which are repeated 5 times
10651060
shape = 10
1066-
n_samples = 10
1061+
n_samples = 2
10671062
n_atoms = 2
10681063
projection = 'nonnegative_symmetric'
10691064
X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
@@ -1100,6 +1095,8 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
11001095
use_adam_optimizer = True
11011096
verbose = False
11021097
tol = 1e-05
1098+
epochs = 1
1099+
11031100
initial_total_reconstruction = 0
11041101
for i in range(n_samples):
11051102
_, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
@@ -1112,7 +1109,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
11121109
# on the learned dictionary is lower than the one using its initialization.
11131110
Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
11141111
Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init,
1115-
epochs=5, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
1112+
epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
11161113
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
11171114
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
11181115
)
@@ -1128,17 +1125,10 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
11281125
np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
11291126

11301127
# Test: Perform same experiments after going through backend
1131-
initial_total_reconstruction_b = 0
1132-
for i in range(n_samples):
1133-
_, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
1134-
Csb[i], Ysb[i], Cdict_initb, Ydict_initb, p=psb[i], q=qb, alpha=alpha, reg=0.,
1135-
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
1136-
)
1137-
initial_total_reconstruction_b += reconstruction
11381128

11391129
Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
11401130
Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb,
1141-
epochs=5, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
1131+
epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
11421132
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
11431133
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
11441134
)
@@ -1151,7 +1141,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
11511141
)
11521142
total_reconstruction_b += reconstruction
11531143

1154-
np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction_b)
1144+
np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
11551145
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
11561146
np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
11571147
np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03)
@@ -1160,7 +1150,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
11601150
np.random.seed(0)
11611151
Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
11621152
Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
1163-
epochs=5, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
1153+
epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
11641154
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
11651155
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
11661156
)
@@ -1179,7 +1169,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
11791169
np.random.seed(0)
11801170
Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
11811171
Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
1182-
epochs=5, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
1172+
epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
11831173
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
11841174
projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
11851175
)
@@ -1199,11 +1189,11 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
11991189
verbose = True
12001190
use_log = True
12011191

1202-
# > Perform similar experiment without providing the initial dictionary being an optional input
1192+
# > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init.
12031193
np.random.seed(0)
12041194
Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
1205-
Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=None, Ydict_init=None,
1206-
epochs=5, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
1195+
Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict,
1196+
epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
12071197
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
12081198
projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
12091199
)
@@ -1216,13 +1206,13 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
12161206
)
12171207
total_reconstruction_bis2 += reconstruction
12181208

1219-
np.testing.assert_array_less(total_reconstruction_bis2, initial_total_reconstruction)
1209+
np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
12201210

12211211
# > Same after going through backend
12221212
np.random.seed(0)
12231213
Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
1224-
Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
1225-
epochs=5, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
1214+
Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb,
1215+
epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
12261216
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
12271217
projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
12281218
)

0 commit comments

Comments
 (0)