Skip to content

Commit f395e58

Browse files
[MRG] fix gpu compatibility of srGW solvers (#596)
* fix gpu compatibility of srgw solvers * update release and pep8
1 parent 336980f commit f395e58

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,4 +354,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
354354

355355
[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.
356356

357-
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
357+
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).

RELEASES.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#### Closed issues
66
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
7-
7+
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
88

99
## 0.9.2
1010
*December 2023*
@@ -671,4 +671,4 @@ It provides the following solvers:
671671
* Optimal transport for domain adaptation with group lasso regularization
672672
* Conditional gradient and Generalized conditional gradient for regularized OT.
673673

674-
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
674+
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.

ot/gromov/_semirelaxed.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
114114
else:
115115
q = nx.sum(G0, 0)
116116
# Check first marginal of G0
117-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
117+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
118118

119119
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
120120

@@ -363,8 +363,8 @@ def semirelaxed_fused_gromov_wasserstein(
363363
G0 = nx.outer(p, q)
364364
else:
365365
q = nx.sum(G0, 0)
366-
# Check marginals of G0
367-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
366+
# Check first marginal of G0
367+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
368368

369369
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
370370

@@ -703,7 +703,7 @@ def entropic_semirelaxed_gromov_wasserstein(
703703
else:
704704
q = nx.sum(G0, 0)
705705
# Check first marginal of G0
706-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
706+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
707707

708708
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
709709

@@ -951,7 +951,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
951951
else:
952952
q = nx.sum(G0, 0)
953953
# Check first marginal of G0
954-
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
954+
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)
955955

956956
constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)
957957

0 commit comments

Comments
 (0)