@@ -465,8 +465,8 @@ def test_partial_fgw2_gradients():
465465@pytest .skip_backend ("tf" , reason = "test very slow with tf backend" )
466466def test_entropic_partial_gromov_wasserstein (nx ):
467467 rng = np .random .RandomState (42 )
468- n_samples = 20 # nb samples
469- n_noise = 10 # nb of samples (noise)
468+ n_samples = 10 # nb samples
469+ n_noise = 5 # nb of samples (noise)
470470
471471 p = ot .unif (n_samples + n_noise )
472472 psub = ot .unif (n_samples - 5 + n_noise )
@@ -516,6 +516,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
516516 log = True ,
517517 symmetric = list_sym [i ],
518518 verbose = True ,
519+ numItermax = 10 ,
519520 )
520521
521522 resb , logb = ot .gromov .entropic_partial_gromov_wasserstein (
@@ -530,6 +531,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
530531 log = True ,
531532 symmetric = False ,
532533 verbose = True ,
534+ numItermax = 10 ,
533535 )
534536
535537 resb_ = nx .to_numpy (resb )
@@ -552,6 +554,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
552554 log = False ,
553555 symmetric = list_sym [i ],
554556 verbose = True ,
557+ numItermax = 10 ,
555558 )
556559
557560 resb = ot .gromov .entropic_partial_gromov_wasserstein (
@@ -564,6 +567,7 @@ def test_entropic_partial_gromov_wasserstein(nx):
564567 log = False ,
565568 symmetric = False ,
566569 verbose = True ,
570+ numItermax = 10 ,
567571 )
568572
569573 resb_ = nx .to_numpy (resb )
@@ -573,11 +577,25 @@ def test_entropic_partial_gromov_wasserstein(nx):
573577 # tests with different number of samples across spaces
574578 m = 0.5
575579 res , log = ot .gromov .entropic_partial_gromov_wasserstein (
576- C1 , C1sub , p = p , q = psub , reg = 1e4 , m = m , log = True
580+ C1 ,
581+ C1sub ,
582+ p = p ,
583+ q = psub ,
584+ reg = 1e4 ,
585+ m = m ,
586+ log = True ,
587+ numItermax = 10 ,
577588 )
578589
579590 resb , logb = ot .gromov .entropic_partial_gromov_wasserstein (
580- C1b , C1subb , p = pb , q = psubb , reg = 1e4 , m = m , log = True
591+ C1b ,
592+ C1subb ,
593+ p = pb ,
594+ q = psubb ,
595+ reg = 1e4 ,
596+ m = m ,
597+ log = True ,
598+ numItermax = 10 ,
581599 )
582600
583601 resb_ = nx .to_numpy (resb )
@@ -589,10 +607,26 @@ def test_entropic_partial_gromov_wasserstein(nx):
589607 # tests for pGW2
590608 for loss_fun in ["square_loss" , "kl_loss" ]:
591609 w0 , log0 = ot .gromov .entropic_partial_gromov_wasserstein2 (
592- C1 , C2 , p = None , q = q , reg = 1e4 , m = m , loss_fun = loss_fun , log = True
610+ C1 ,
611+ C2 ,
612+ p = None ,
613+ q = q ,
614+ reg = 1e4 ,
615+ m = m ,
616+ loss_fun = loss_fun ,
617+ log = True ,
618+ numItermax = 10 ,
593619 )
594620 w0_val = ot .gromov .entropic_partial_gromov_wasserstein2 (
595- C1b , C2b , p = pb , q = None , reg = 1e4 , m = m , loss_fun = loss_fun , log = False
621+ C1b ,
622+ C2b ,
623+ p = pb ,
624+ q = None ,
625+ reg = 1e4 ,
626+ m = m ,
627+ loss_fun = loss_fun ,
628+ log = False ,
629+ numItermax = 10 ,
596630 )
597631 np .testing .assert_allclose (w0 , w0_val , rtol = 1e-8 )
598632
@@ -666,6 +700,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
666700 log = True ,
667701 symmetric = list_sym [i ],
668702 verbose = True ,
703+ numItermax = 10 ,
669704 )
670705
671706 resb , logb = ot .gromov .entropic_partial_fused_gromov_wasserstein (
@@ -681,6 +716,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
681716 log = True ,
682717 symmetric = False ,
683718 verbose = True ,
719+ numItermax = 10 ,
684720 )
685721
686722 resb_ = nx .to_numpy (resb )
@@ -704,6 +740,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
704740 log = False ,
705741 symmetric = list_sym [i ],
706742 verbose = True ,
743+ numItermax = 10 ,
707744 )
708745
709746 resb = ot .gromov .entropic_partial_fused_gromov_wasserstein (
@@ -717,6 +754,7 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
717754 log = False ,
718755 symmetric = False ,
719756 verbose = True ,
757+ numItermax = 10 ,
720758 )
721759
722760 resb_ = nx .to_numpy (resb )
@@ -726,11 +764,27 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
726764 # tests with different number of samples across spaces
727765 m = 0.5
728766 res , log = ot .gromov .entropic_partial_fused_gromov_wasserstein (
729- M11sub , C1 , C1sub , p = p , q = psub , reg = 1e4 , m = m , log = True
767+ M11sub ,
768+ C1 ,
769+ C1sub ,
770+ p = p ,
771+ q = psub ,
772+ reg = 1e4 ,
773+ m = m ,
774+ log = True ,
775+ numItermax = 10 ,
730776 )
731777
732778 resb , logb = ot .gromov .entropic_partial_fused_gromov_wasserstein (
733- M11subb , C1b , C1subb , p = pb , q = psubb , reg = 1e4 , m = m , log = True
779+ M11subb ,
780+ C1b ,
781+ C1subb ,
782+ p = pb ,
783+ q = psubb ,
784+ reg = 1e4 ,
785+ m = m ,
786+ log = True ,
787+ numItermax = 10 ,
734788 )
735789
736790 resb_ = nx .to_numpy (resb )
@@ -742,9 +796,27 @@ def test_entropic_partial_fused_gromov_wasserstein(nx):
742796 # tests for pGW2
743797 for loss_fun in ["square_loss" , "kl_loss" ]:
744798 w0 , log0 = ot .gromov .entropic_partial_fused_gromov_wasserstein2 (
745- M12 , C1 , C2 , p = None , q = q , reg = 1e4 , m = m , loss_fun = loss_fun , log = True
799+ M12 ,
800+ C1 ,
801+ C2 ,
802+ p = None ,
803+ q = q ,
804+ reg = 1e4 ,
805+ m = m ,
806+ loss_fun = loss_fun ,
807+ log = True ,
808+ numItermax = 10 ,
746809 )
747810 w0_val = ot .gromov .entropic_partial_fused_gromov_wasserstein2 (
748- M12b , C1b , C2b , p = pb , q = None , reg = 1e4 , m = m , loss_fun = loss_fun , log = False
811+ M12b ,
812+ C1b ,
813+ C2b ,
814+ p = pb ,
815+ q = None ,
816+ reg = 1e4 ,
817+ m = m ,
818+ loss_fun = loss_fun ,
819+ log = False ,
820+ numItermax = 10 ,
749821 )
750822 np .testing .assert_allclose (w0 , w0_val , rtol = 1e-8 )
0 commit comments