@@ -344,6 +344,188 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
344344 return W
345345
346346
347+ def bures_wasserstein_barycenter (m , C , weights = None , num_iter = 1000 , eps = 1e-7 , log = False ):
348+ r"""Return OT linear operator between samples.
349+
350+ The function estimates the optimal barycenter of the
351+ empirical distributions. This is equivalent to resolving the fixed point
352+ algorithm for multiple Gaussian distributions :math:`\left{\mathcal{N}(\mu,\Sigma)\right}_{i=1}^n`
353+ :ref:`[1] <references-OT-mapping-linear-barycenter>`.
354+
355+ The barycenter still following a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)`
356+ where :
357+
358+ .. math::
359+ \mu_b = \sum_{i=1}^n w_i \mu_i
360+
361+ And the barycentric covariance is the solution of the following fixed-point algorithm:
362+
363+ .. math::
364+ \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2}
365+
366+
367+ Parameters
368+ ----------
369+ m : array-like (k,d)
370+ mean of k distributions
371+ C : array-like (k,d,d)
372+ covariance of k distributions
373+ weights : array-like (k), optional
374+ weights for each distribution
375+ num_iter : int, optional
376+ number of iteration for the fixed point algorithm
377+ eps : float, optional
378+ tolerance for the fixed point algorithm
379+ log : bool, optional
380+ record log if True
381+
382+
383+ Returns
384+ -------
385+ mb : (d,) array-like
386+ mean of the barycenter
387+ Cb : (d, d) array-like
388+ covariance of the barycenter
389+ log : dict
390+ log dictionary return only if log==True in parameters
391+
392+
393+ .. _references-OT-mapping-linear-barycenter:
394+ References
395+ ----------
396+ .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
397+ SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
398+ 2011.
399+ """
400+ nx = get_backend (* C , * m ,)
401+
402+ # Compute the mean barycenter
403+ mb = nx .mean (m )
404+
405+ # Init the covariance barycenter
406+ Cb = nx .mean (C , axis = 0 )
407+
408+ if weights is None :
409+ weights = nx .ones (len (C ), type_as = C [0 ]) / len (C )
410+
411+ for it in range (num_iter ):
412+ # fixed point update
413+ Cb12 = nx .sqrtm (Cb )
414+
415+ Cnew = Cb12 @ C @ Cb12
416+ C_ = []
417+ for i in range (len (C )):
418+ C_ .append (nx .sqrtm (Cnew [i ]))
419+ Cnew = nx .stack (C_ , axis = 0 )
420+ Cnew *= weights [:, None , None ]
421+ Cnew = nx .sum (Cnew , axis = 0 )
422+
423+ # check convergence
424+ diff = nx .norm (Cb - Cnew )
425+ if diff <= eps :
426+ break
427+ Cb = Cnew
428+ else :
429+ print ("Dit not converge." )
430+
431+ if log :
432+ log = {}
433+ log ['num_iter' ] = it
434+ log ['final_diff' ] = diff
435+ return mb , Cb , log
436+ else :
437+ return mb , Cb
438+
439+
440+ def empirical_bures_wasserstein_barycenter (
441+ X , reg = 1e-6 , weights = None , num_iter = 1000 , eps = 1e-7 ,
442+ w = None , bias = True , log = False
443+ ):
444+ r"""Return OT linear operator between samples.
445+
446+ The function estimates the optimal barycenter of the
447+ empirical distributions. This is equivalent to resolving the fixed point
448+ algorithm for multiple Gaussian distributions :math:`\left{\mathcal{N}(\mu,\Sigma)\right}_{i=1}^n`
449+ :ref:`[1] <references-OT-mapping-linear-barycenter>`.
450+
451+ The barycenter still following a Gaussian distribution :math:`\mathcal{N}(\mu_b,\Sigma_b)`
452+ where :
453+
454+ .. math::
455+ \mu_b = \sum_{i=1}^n w_i \mu_i
456+
457+ And the barycentric covariance is the solution of the following fixed-point algorithm:
458+
459+ .. math::
460+ \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i^{1/2}\Sigma_b^{1/2}\right)^{1/2}
461+
462+
463+ Parameters
464+ ----------
465+ X : list of array-like (n,d)
466+ samples in each distribution
467+ reg : float,optional
468+ regularization added to the diagonals of covariances (>0)
469+ weights : array-like (n,), optional
470+ weights for each distribution
471+ num_iter : int, optional
472+ number of iteration for the fixed point algorithm
473+ eps : float, optional
474+ tolerance for the fixed point algorithm
475+ w : list of array-like (n,), optional
476+ weights for each sample in each distribution
477+ bias: boolean, optional
478+ estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
479+ log : bool, optional
480+ record log if True
481+
482+
483+ Returns
484+ -------
485+ mb : (d,) array-like
486+ mean of the barycenter
487+ Cb : (d, d) array-like
488+ covariance of the barycenter
489+ log : dict
490+ log dictionary return only if log==True in parameters
491+
492+
493+ .. _references-OT-mapping-linear-barycenter:
494+ References
495+ ----------
496+ .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
497+ SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
498+ 2011.
499+ """
500+ X = list_to_array (* X )
501+ nx = get_backend (* X )
502+
503+ k = len (X )
504+ d = [X [i ].shape [1 ] for i in range (k )]
505+
506+ if bias :
507+ m = [nx .mean (X [i ], axis = 0 )[None , :] for i in range (k )]
508+ X = [X [i ] - m [i ] for i in range (k )]
509+ else :
510+ m = [nx .zeros ((1 , d [i ]), type_as = X [i ]) for i in range (k )]
511+
512+ if w is None :
513+ w = [nx .ones ((X [i ].shape [0 ], 1 ), type_as = X [i ]) / X [i ].shape [0 ] for i in range (k )]
514+
515+ C = [
516+ nx .dot ((X [i ] * w [i ]).T , X [i ]) / nx .sum (w [i ]) + reg * nx .eye (d [i ], type_as = X [i ])
517+ for i in range (k )
518+ ]
519+ m = nx .stack (m , axis = 0 )
520+ C = nx .stack (C , axis = 0 )
521+ if log :
522+ mb , Cb , log = bures_wasserstein_barycenter (m , C , weights = weights , num_iter = num_iter , eps = eps , log = log )
523+ return mb , Cb , log
524+ else :
525+ mb , Cb = bures_wasserstein_barycenter (m , C , weights = weights , num_iter = num_iter , eps = eps , log = log )
526+ return mb , Cb
527+
528+
347529def gaussian_gromov_wasserstein_distance (Cov_s , Cov_t , log = False ):
348530 r""" Return the Gaussian Gromov-Wasserstein value from [57].
349531
0 commit comments