|
11 | 11 | # Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
|
12 | 12 | # Alexander Tong <alexander.tong@yale.edu>
|
13 | 13 | # Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
|
| 14 | +# Quang Huy Tran <quang-huy.tran@univ-ubs.fr> |
14 | 15 | #
|
15 | 16 | # License: MIT License
|
16 | 17 |
|
17 | 18 | import warnings
|
18 | 19 |
|
19 | 20 | import numpy as np
|
20 | 21 | from scipy.optimize import fmin_l_bfgs_b
|
| 22 | +from scipy.special import logsumexp |
21 | 23 |
|
22 | 24 | from ot.utils import unif, dist, list_to_array
|
23 | 25 | from .backend import get_backend
|
@@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
|
1684 | 1686 |
|
1685 | 1687 |
|
1686 | 1688 | def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
|
1687 |
| - numIterMax=10000, stopThr=1e-9, verbose=False, |
| 1689 | + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, |
1688 | 1690 | log=False, **kwargs):
|
1689 | 1691 | r'''
|
1690 | 1692 | Solve the entropic regularization optimal transport problem and return the
|
@@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
|
1723 | 1725 | Max number of iterations
|
1724 | 1726 | stopThr : float, optional
|
1725 | 1727 | Stop threshol on error (>0)
|
| 1728 | + isLazy: boolean, optional |
| 1729 | + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) |
| 1730 | + If False, calculate full cost matrix and return outputs of sinkhorn function. |
| 1731 | + batchSize: int or tuple of 2 int, optional |
| 1732 | + Size of the batcheses used to compute the sinkhorn update without memory overhead. |
| 1733 | + When a tuple is provided it sets the size of the left/right batches. |
1726 | 1734 | verbose : bool, optional
|
1727 | 1735 | Print information along iterations
|
1728 | 1736 | log : bool, optional
|
@@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
|
1758 | 1766 |
|
1759 | 1767 | .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
|
1760 | 1768 | '''
|
1761 |
| - |
| 1769 | + ns, nt = X_s.shape[0], X_t.shape[0] |
1762 | 1770 | if a is None:
|
1763 |
| - a = unif(np.shape(X_s)[0]) |
| 1771 | + a = unif(ns) |
1764 | 1772 | if b is None:
|
1765 |
| - b = unif(np.shape(X_t)[0]) |
| 1773 | + b = unif(nt) |
| 1774 | + |
| 1775 | + if isLazy: |
| 1776 | + if log: |
| 1777 | + dict_log = {"err": []} |
1766 | 1778 |
|
1767 |
| - M = dist(X_s, X_t, metric=metric) |
| 1779 | + log_a, log_b = np.log(a), np.log(b) |
| 1780 | + f, g = np.zeros(ns), np.zeros(nt) |
| 1781 | + |
| 1782 | + if isinstance(batchSize, int): |
| 1783 | + bs, bt = batchSize, batchSize |
| 1784 | + elif isinstance(batchSize, tuple) and len(batchSize) == 2: |
| 1785 | + bs, bt = batchSize[0], batchSize[1] |
| 1786 | + else: |
| 1787 | + raise ValueError("Batch size must be in integer or a tuple of two integers") |
| 1788 | + |
| 1789 | + range_s, range_t = range(0, ns, bs), range(0, nt, bt) |
| 1790 | + |
| 1791 | + lse_f = np.zeros(ns) |
| 1792 | + lse_g = np.zeros(nt) |
| 1793 | + |
| 1794 | + for i_ot in range(numIterMax): |
| 1795 | + |
| 1796 | + for i in range_s: |
| 1797 | + M = dist(X_s[i:i + bs, :], X_t, metric=metric) |
| 1798 | + lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1) |
| 1799 | + f = log_a - lse_f |
| 1800 | + |
| 1801 | + for j in range_t: |
| 1802 | + M = dist(X_s, X_t[j:j + bt, :], metric=metric) |
| 1803 | + lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0) |
| 1804 | + g = log_b - lse_g |
| 1805 | + |
| 1806 | + if (i_ot + 1) % 10 == 0: |
| 1807 | + m1 = np.zeros_like(a) |
| 1808 | + for i in range_s: |
| 1809 | + M = dist(X_s[i:i + bs, :], X_t, metric=metric) |
| 1810 | + m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1) |
| 1811 | + err = np.abs(m1 - a).sum() |
| 1812 | + if log: |
| 1813 | + dict_log["err"].append(err) |
| 1814 | + |
| 1815 | + if verbose and (i_ot + 1) % 100 == 0: |
| 1816 | + print("Error in marginal at iteration {} = {}".format(i_ot + 1, err)) |
| 1817 | + |
| 1818 | + if err <= stopThr: |
| 1819 | + break |
| 1820 | + |
| 1821 | + if log: |
| 1822 | + dict_log["u"] = f |
| 1823 | + dict_log["v"] = g |
| 1824 | + return (f, g, dict_log) |
| 1825 | + else: |
| 1826 | + return (f, g) |
1768 | 1827 |
|
1769 |
| - if log: |
1770 |
| - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) |
1771 |
| - return pi, log |
1772 | 1828 | else:
|
1773 |
| - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) |
1774 |
| - return pi |
| 1829 | + M = dist(X_s, X_t, metric=metric) |
| 1830 | + |
| 1831 | + if log: |
| 1832 | + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) |
| 1833 | + return pi, log |
| 1834 | + else: |
| 1835 | + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) |
| 1836 | + return pi |
1775 | 1837 |
|
1776 | 1838 |
|
1777 | 1839 | def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
|
1778 |
| - verbose=False, log=False, **kwargs): |
| 1840 | + isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): |
1779 | 1841 | r'''
|
1780 | 1842 | Solve the entropic regularization optimal transport problem from empirical
|
1781 | 1843 | data and return the OT loss
|
@@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
|
1814 | 1876 | Max number of iterations
|
1815 | 1877 | stopThr : float, optional
|
1816 | 1878 | Stop threshol on error (>0)
|
| 1879 | + isLazy: boolean, optional |
| 1880 | + If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory) |
| 1881 | + If False, calculate full cost matrix and return outputs of sinkhorn function. |
| 1882 | + batchSize: int or tuple of 2 int, optional |
| 1883 | + Size of the batcheses used to compute the sinkhorn update without memory overhead. |
| 1884 | + When a tuple is provided it sets the size of the left/right batches. |
1817 | 1885 | verbose : bool, optional
|
1818 | 1886 | Print information along iterations
|
1819 | 1887 | log : bool, optional
|
@@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
|
1850 | 1918 | .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
|
1851 | 1919 | '''
|
1852 | 1920 |
|
| 1921 | + ns, nt = X_s.shape[0], X_t.shape[0] |
1853 | 1922 | if a is None:
|
1854 |
| - a = unif(np.shape(X_s)[0]) |
| 1923 | + a = unif(ns) |
1855 | 1924 | if b is None:
|
1856 |
| - b = unif(np.shape(X_t)[0]) |
| 1925 | + b = unif(nt) |
1857 | 1926 |
|
1858 |
| - M = dist(X_s, X_t, metric=metric) |
| 1927 | + if isLazy: |
| 1928 | + if log: |
| 1929 | + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, |
| 1930 | + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) |
| 1931 | + else: |
| 1932 | + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, |
| 1933 | + isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) |
| 1934 | + |
| 1935 | + bs = batchSize if isinstance(batchSize, int) else batchSize[0] |
| 1936 | + range_s = range(0, ns, bs) |
| 1937 | + |
| 1938 | + loss = 0 |
| 1939 | + for i in range_s: |
| 1940 | + M_block = dist(X_s[i:i + bs, :], X_t, metric=metric) |
| 1941 | + pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) |
| 1942 | + loss += np.sum(M_block * pi_block) |
| 1943 | + |
| 1944 | + if log: |
| 1945 | + return loss, dict_log |
| 1946 | + else: |
| 1947 | + return loss |
1859 | 1948 |
|
1860 |
| - if log: |
1861 |
| - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, |
1862 |
| - **kwargs) |
1863 |
| - return sinkhorn_loss, log |
1864 | 1949 | else:
|
1865 |
| - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, |
1866 |
| - **kwargs) |
1867 |
| - return sinkhorn_loss |
| 1950 | + M = dist(X_s, X_t, metric=metric) |
| 1951 | + |
| 1952 | + if log: |
| 1953 | + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, |
| 1954 | + **kwargs) |
| 1955 | + return sinkhorn_loss, log |
| 1956 | + else: |
| 1957 | + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, |
| 1958 | + **kwargs) |
| 1959 | + return sinkhorn_loss |
1868 | 1960 |
|
1869 | 1961 |
|
1870 | 1962 | def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
|
|
0 commit comments