2929
3030def solve (M , a = None , b = None , reg = None , reg_type = "KL" , unbalanced = None ,
3131 unbalanced_type = 'KL' , method = None , n_threads = 1 , max_iter = None , plan_init = None ,
32- potentials_init = None , tol = None , verbose = False ):
32+ potentials_init = None , tol = None , verbose = False , grad = 'autodiff' ):
3333 r"""Solve the discrete optimal transport problem and return :any:`OTResult` object
3434
3535 The function solves the following general optimal transport problem
@@ -79,6 +79,12 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
7979 Tolerance for solution precision, by default None (default values in each solvers)
8080 verbose : bool, optional
8181 Print information in the solver, by default False
82+ grad : str, optional
83+ Type of gradient computation, either 'autodiff' or 'implicit', used only for the
84+ Sinkhorn solver. By default 'autodiff' provides gradients w.r.t. all
85+ outputs (`plan, value, value_linear`) but with a significant memory cost.
86+ 'implicit' provides gradients only for `value`; the other outputs are
87+ detached. This is useful for memory saving when only the value is needed.
8288
8389 Returns
8490 -------
@@ -134,6 +140,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
134140 # or for original Sinkhorn paper formulation [2]
135141 res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
136142
143+ # Use implicit differentiation for memory saving
144+ res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors
145+ res.value.backward() # only the value is differentiable
146+
147+ Note that by default the Sinkhorn solver uses automatic differentiation to
148+ compute the gradients of the value and the plan. This can be changed with the
149+ `grad` parameter. The `implicit` mode computes the implicit gradients only
150+ for the value, and the other outputs are detached. This is useful for
151+ memory saving when only the gradient of the value is needed.
152+
137153 - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):
138154
139155 .. math::
@@ -297,6 +313,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
297313
298314 if reg_type .lower () in ['entropy' , 'kl' ]:
299315
316+ if grad == 'implicit' : # if implicit then detach the input
317+ M0 , a0 , b0 = M , a , b
318+ M , a , b = nx .detach (M , a , b )
319+
300320 # default values for sinkhorn
301321 if max_iter is None :
302322 max_iter = 1000
@@ -316,6 +336,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
316336
317337 potentials = (log ['log_u' ], log ['log_v' ])
318338
339+ if grad == 'implicit' : # set the gradient at convergence
340+
341+ value = nx .set_gradients (value , (M0 , a0 , b0 ),
342+ (plan , reg * (potentials [0 ] - potentials [0 ].mean ()), reg * (potentials [1 ] - potentials [1 ].mean ())))
343+
319344 elif reg_type .lower () == 'l2' :
320345
321346 if max_iter is None :
@@ -869,7 +894,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
869894def solve_sample (X_a , X_b , a = None , b = None , metric = 'sqeuclidean' , reg = None , reg_type = "KL" ,
870895 unbalanced = None ,
871896 unbalanced_type = 'KL' , lazy = False , batch_size = None , method = None , n_threads = 1 , max_iter = None , plan_init = None , rank = 100 , scaling = 0.95 ,
872- potentials_init = None , X_init = None , tol = None , verbose = False ):
897+ potentials_init = None , X_init = None , tol = None , verbose = False ,
898+ grad = 'autodiff' ):
873899 r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
874900
875901 The function solves the following general optimal transport problem
@@ -935,6 +961,12 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
935961 Tolerance for solution precision, by default None (default values in each solvers)
936962 verbose : bool, optional
937963 Print information in the solver, by default False
964+ grad : str, optional
965+ Type of gradient computation, either 'autodiff' or 'implicit', used only for the
966+ Sinkhorn solver. By default 'autodiff' provides gradients w.r.t. all
967+ outputs (`plan, value, value_linear`) but with a significant memory cost.
968+ 'implicit' provides gradients only for `value`; the other outputs are
969+ detached. This is useful for memory saving when only the value is needed.
938970
939971 Returns
940972 -------
@@ -1002,6 +1034,16 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
10021034 # lazy OT plan
10031035 lazy_plan = res.lazy_plan
10041036
1037+ # Use implicit differentiation for memory saving
1038+ res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
1039+ res.value.backward() # only the value is differentiable
1040+
1041+ Note that by default the Sinkhorn solver uses automatic differentiation to
1042+ compute the gradients of the value and the plan. This can be changed with the
1043+ `grad` parameter. The `implicit` mode computes the implicit gradients only
1044+ for the value, and the other outputs are detached. This is useful for
1045+ memory saving when only the gradient of the value is needed.
1046+
10051047 We also have a very efficient solver with compiled CPU/CUDA code using
10061048 geomloss/PyKeOps that can be used with the following code:
10071049
@@ -1189,7 +1231,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
11891231 # compute cost matrix M and use solve function
11901232 M = dist (X_a , X_b , metric )
11911233
1192- res = solve (M , a , b , reg , reg_type , unbalanced , unbalanced_type , method , n_threads , max_iter , plan_init , potentials_init , tol , verbose )
1234+ res = solve (M , a , b , reg , reg_type , unbalanced , unbalanced_type , method , n_threads , max_iter , plan_init , potentials_init , tol , verbose , grad )
11931235
11941236 return res
11951237
0 commit comments