@@ -798,7 +798,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
 
     # create dataset composed from 2 structures which are repeated 5 times
     shape = 10
-    n_samples = 10
+    n_samples = 2
     n_atoms = 2
     projection = 'nonnegative_symmetric'
     X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
@@ -827,6 +827,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
     use_adam_optimizer = True
     verbose = False
     tol = 10**(-5)
+    epochs = 1
 
     initial_total_reconstruction = 0
     for i in range(n_samples):
@@ -839,7 +840,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
     # > Learn the dictionary using this init
     Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
         Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init,
-        epochs=5, batch_size=2 * n_samples, learning_rate=1., reg=0.,
+        epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -855,17 +856,10 @@ def test_gromov_wasserstein_dictionary_learning(nx):
     np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
 
     # Test: Perform same experiments after going through backend
-    initial_total_reconstruction_b = 0
-    for i in range(n_samples):
-        _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
-            Csb[i], Cdict_initb, p=psb[i], q=qb, reg=0.,
-            tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
-        )
-        initial_total_reconstruction_b += reconstruction
 
     Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning(
         Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb,
-        epochs=5, batch_size=n_samples, learning_rate=1., reg=0.,
+        epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -878,7 +872,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
         )
         total_reconstruction_b += reconstruction
 
-    np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction_b)
+    np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
     np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
     np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
     np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
@@ -888,7 +882,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
     np.random.seed(0)
     Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
         Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate=1., reg=0.,
+        epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -907,7 +901,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
     np.random.seed(0)
     Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning(
         Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate=1., reg=0.,
+        epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -925,14 +919,15 @@ def test_gromov_wasserstein_dictionary_learning(nx):
 
     # Test: Perform same comparison without providing the initial dictionary being an optional input
     # and testing other optimization settings untested until now.
+    # We pass previously estimated dictionaries to speed up the process.
     use_adam_optimizer = False
     verbose = True
     use_log = True
 
     np.random.seed(0)
     Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
-        Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate=10., reg=0.,
+        Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict,
+        epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -945,13 +940,13 @@ def test_gromov_wasserstein_dictionary_learning(nx):
         )
         total_reconstruction_bis2 += reconstruction
 
-    np.testing.assert_array_less(total_reconstruction_bis2, initial_total_reconstruction)
+    np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
 
     # Test: Same after going through backend
     np.random.seed(0)
     Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
-        Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate=10., reg=0.,
+        Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb,
+        epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -1063,7 +1058,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
 
     # create dataset composed from 2 structures which are repeated 5 times
     shape = 10
-    n_samples = 10
+    n_samples = 2
     n_atoms = 2
     projection = 'nonnegative_symmetric'
     X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
@@ -1100,6 +1095,8 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
     use_adam_optimizer = True
     verbose = False
     tol = 1e-05
+    epochs = 1
+
     initial_total_reconstruction = 0
     for i in range(n_samples):
         _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
@@ -1112,7 +1109,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
     # on the learned dictionary is lower than the one using its initialization.
     Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
         Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init,
-        epochs=5, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+        epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -1128,17 +1125,10 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
     np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction)
 
     # Test: Perform same experiments after going through backend
-    initial_total_reconstruction_b = 0
-    for i in range(n_samples):
-        _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing(
-            Csb[i], Ysb[i], Cdict_initb, Ydict_initb, p=psb[i], q=qb, alpha=alpha, reg=0.,
-            tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
-        )
-        initial_total_reconstruction_b += reconstruction
 
     Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
         Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb,
-        epochs=5, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+        epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -1151,7 +1141,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
         )
         total_reconstruction_b += reconstruction
 
-    np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction_b)
+    np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
     np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
     np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
     np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03)
@@ -1160,7 +1150,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
     np.random.seed(0)
     Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
         Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+        epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -1179,7 +1169,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
     np.random.seed(0)
     Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
         Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
+        epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -1199,11 +1189,11 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
     verbose = True
     use_log = True
 
-    # > Perform similar experiment without providing the initial dictionary being an optional input
+    # > Experiment providing the previously estimated dictionary to speed up the test, compared to a random init.
     np.random.seed(0)
     Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
-        Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=None, Ydict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+        Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict,
+        epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
@@ -1216,13 +1206,13 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
         )
         total_reconstruction_bis2 += reconstruction
 
-    np.testing.assert_array_less(total_reconstruction_bis2, initial_total_reconstruction)
+    np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction)
 
     # > Same after going through backend
     np.random.seed(0)
     Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning(
-        Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None,
-        epochs=5, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
+        Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb,
+        epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0.,
         tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
         projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose
     )
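
For reference, here is a minimal, self-contained sketch of the speed-up pattern this diff applies to both tests: a tiny dataset (n_samples = 2), a single epoch, and warm-starting later runs from the previously learned dictionary instead of a fresh random one. It only reuses calls visible in the diff above; `ot.dist` and `ot.unif` are assumed helpers for building toy inputs and are not part of this change.

```python
import ot

# Toy inputs (assumed construction; the actual test builds its own Cs, ps, q).
shape, n_samples, n_atoms = 10, 2, 2
X1, _ = ot.datasets.make_data_classif('3gauss', shape, random_state=42)
Cs = [ot.dist(X1) for _ in range(n_samples)]     # small symmetric structure matrices
ps = [ot.unif(shape) for _ in range(n_samples)]  # uniform node weights per structure
q = ot.unif(shape)                               # uniform weights for the atoms
tol = 10**(-5)

# Single-epoch dictionary learning, as configured by this change (epochs = 1).
Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning(
    Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=None,
    epochs=1, batch_size=n_samples, learning_rate=1., reg=0.,
    tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
    projection='nonnegative_symmetric', use_log=False,
    use_adam_optimizer=True, verbose=False
)

# Total reconstruction error of the samples unmixed on the learned dictionary.
total_reconstruction = 0
for i in range(n_samples):
    _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing(
        Cs[i], Cdict, p=ps[i], q=q, reg=0.,
        tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
    )
    total_reconstruction += reconstruction

# Warm start: later runs reuse the learned Cdict as Cdict_init, as the diff now
# does, rather than restarting from a random dictionary with more epochs.
Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning(
    Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict,
    epochs=1, batch_size=n_samples, learning_rate=10., reg=0.,
    tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200,
    projection='nonnegative_symmetric', use_log=True,
    use_adam_optimizer=False, verbose=False
)
```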