# Source code for ot.lp.solver_1d

# -*- coding: utf-8 -*-
"""
Exact solvers for the 1D Wasserstein distance and for the Wasserstein distance on the circle
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import numpy as np
import warnings

from .emd_wrap import emd_1d_sorted
from ..backend import get_backend
from ..utils import list_to_array


def quantile_function(qs, cws, xs):
    r"""Evaluate the quantile function of (batched) 1D empirical distributions.

    For every level ``q`` of ``qs`` the location ``xs[i]`` is selected, where
    ``i`` is the first index whose cumulative weight ``cws[i]`` is greater
    than or equal to ``q`` (generalized inverse of the cdf). Indices past the
    last atom are clipped to the last location.

    Parameters
    ----------
    qs: array-like, shape (k, ...)
        Levels (quantiles) at which the quantile function is evaluated,
        batched like ``cws``
    cws: array-like, shape (n, ...)
        cumulative weights of the 1D empirical distribution, sorted along the
        first axis; if batched, must be similar to ``xs``
    xs: array-like, shape (n, ...)
        locations of the 1D empirical distribution, batched against the
        ``xs.ndim - 1`` trailing dimensions

    Returns
    -------
    q: array-like, shape (k, ...)
        The locations selected for each level of ``qs``
    """
    nx = get_backend(qs, cws)
    n_atoms = xs.shape[0]

    # samples lie along axis 0; transpose them to the last axis for searchsorted
    cws_t = cws.T
    qs_t = qs.T
    if nx.__name__ == "torch":
        # contiguous inputs give the best torch searchsorted performance and
        # avoid a warning related to non-contiguous arrays
        cws_t = cws_t.contiguous()
        qs_t = qs_t.contiguous()

    positions = nx.searchsorted(cws_t, qs_t).T
    # a level above the last cumulative weight would point past the end
    positions = nx.clip(positions, 0, n_atoms - 1)
    return nx.take_along_axis(xs, positions, axis=0)


def wasserstein_1d(
    u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True
):
    r"""
    Computes the 1 dimensional OT loss [15] between two (batched) empirical
    distributions

    .. math::
        OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq

    It is formally the p-Wasserstein distance raised to the power p.
    We do so in a vectorized way by first building the individual quantile
    functions then integrating them.

    This function should be preferred to `emd_1d` whenever the backend is
    different to numpy, and when gradients over either sample positions or
    weights are required.

    Parameters
    ----------
    u_values: array-like, shape (n, ...)
        locations of the first empirical distribution
    v_values: array-like, shape (m, ...)
        locations of the second empirical distribution
    u_weights: array-like, shape (n, ...), optional
        weights of the first empirical distribution, if None then uniform
        weights are used. Weights with fewer dimensions than ``u_values`` are
        repeated along its last axis. Weights are used as given (they are not
        renormalized).
    v_weights: array-like, shape (m, ...), optional
        weights of the second empirical distribution, if None then uniform
        weights are used. Same broadcasting rule as ``u_weights``.
    p: int, optional
        order of the ground metric used, should be at least 1
        (see [2, Chap. 2]), default is 1
    require_sort: bool, optional
        sort the distributions atoms locations, if False we will consider they
        have been sorted prior to being passed to the function, default is True

    Returns
    -------
    cost: float/array-like, shape (...)
        the batched EMD

    References
    ----------
    .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.

    """
    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

    if u_weights is None or v_weights is None:
        nx = get_backend(u_values, v_values)
    else:
        nx = get_backend(u_values, v_values, u_weights, v_weights)

    n_u, n_v = u_values.shape[0], v_values.shape[0]

    # uniform weights by default; lower-rank weights are broadcast over the
    # trailing (batch) axis of the values
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n_u, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / n_v, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    if require_sort:
        u_order = nx.argsort(u_values, 0)
        v_order = nx.argsort(v_values, 0)
        u_values = nx.take_along_axis(u_values, u_order, 0)
        u_weights = nx.take_along_axis(u_weights, u_order, 0)
        v_values = nx.take_along_axis(v_values, v_order, 0)
        v_weights = nx.take_along_axis(v_weights, v_order, 0)

    u_cdf = nx.cumsum(u_weights, 0)
    v_cdf = nx.cumsum(v_weights, 0)

    # merged breakpoints of both cdfs: both quantile functions are constant
    # between two consecutive levels
    levels = nx.sort(nx.concatenate((u_cdf, v_cdf), 0), 0)
    u_quantiles = quantile_function(levels, u_cdf, u_values)
    v_quantiles = quantile_function(levels, v_cdf, v_values)

    # width of each constant piece, starting from level 0
    padding = [(1, 0)] + [(0, 0)] * (levels.ndim - 1)
    levels = nx.zero_pad(levels, pad_width=padding)
    widths = levels[1:, ...] - levels[:-1, ...]

    gaps = nx.abs(u_quantiles - v_quantiles)
    if p != 1:
        gaps = nx.power(gaps, p)
    return nx.sum(widths * gaps, axis=0)
def emd_1d(
    x_a,
    x_b,
    a=None,
    b=None,
    metric="sqeuclidean",
    p=1.0,
    dense=True,
    log=False,
    check_marginals=True,
):
    r"""Solves the Earth Movers distance problem between 1d measures and returns
    the OT matrix


    .. math::
        \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0
    where :

    - d is the metric
    - x_a and x_b are the samples
    - a and b are the sample weights

    This implementation only supports metrics
    of the form :math:`d(x, y) = |x - y|^p`.

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only works with either of the strings
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`,  or `'euclidean'`.
    p: float, optional (default=1.0)
        The p-norm to apply for if metric='minkowski'
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation in `coo_matrix` format
        (with the JAX backend a warning is emitted instead).
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the cost.
        Otherwise returns only the optimal transportation matrix.
    check_marginals: bool, optional (default=True)
        If True, checks that the marginals mass are equal. If False, skips the
        check. In both cases `b` is rescaled to the total mass of `a`.

    Returns
    -------
    gamma: (ns, nt) ndarray
        Optimal transportation matrix for the given parameters
    log: dict
        If input log is True, a dictionary containing the cost


    Examples
    --------

    Simple example with obvious solution. The function emd_1d accepts lists and
    performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.emd_1d(x_a, x_b, a, b)
    array([[0. , 0.5],
           [0.5, 0. ]])
    >>> ot.emd_1d(x_a, x_b)
    array([[0. , 0.5],
           [0.5, 0. ]])

    References
    ----------
    .. [1]  Peyré, G., & Cuturi, M. (2018). "Computational Optimal Transport".

    See Also
    --------
    ot.lp.emd : EMD for multidimensional distributions
    ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
        transportation matrix)
    """
    x_a, x_b = list_to_array(x_a, x_b)
    nx = get_backend(x_a, x_b)
    a = None if a is None else list_to_array(a, nx=nx)
    b = None if b is None else list_to_array(b, nx=nx)

    for samples in (x_a, x_b):
        assert samples.ndim == 1 or (
            samples.ndim == 2 and samples.shape[1] == 1
        ), "emd_1d should only be used with monodimensional data"

    if metric not in ("sqeuclidean", "minkowski", "cityblock", "euclidean"):
        raise ValueError(
            "Solver for EMD in 1d only supports metrics "
            + "from the following list: "
            + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
        )

    def _use_uniform(weights):
        # None, scalar or empty weights all fall back to uniform weights
        return weights is None or weights.ndim == 0 or len(weights) == 0

    if _use_uniform(a):
        a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
    if _use_uniform(b):
        b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]

    if check_marginals:
        np.testing.assert_almost_equal(
            nx.to_numpy(nx.sum(a, axis=0)),
            nx.to_numpy(nx.sum(b, axis=0)),
            err_msg="a and b vector must have the same sum",
            decimal=6,
        )
    # make both marginals carry exactly the same total mass
    b = b * nx.sum(a) / nx.sum(b)

    x_a_flat = nx.reshape(x_a, (-1,))
    x_b_flat = nx.reshape(x_b, (-1,))
    order_a = nx.argsort(x_a_flat)
    order_b = nx.argsort(x_b_flat)

    def _sorted_float64(arr, order):
        # the compiled solver expects sorted float64 numpy arrays
        return nx.to_numpy(arr[order]).astype(np.float64)

    G_sorted, indices, cost = emd_1d_sorted(
        _sorted_float64(a, order_a),
        _sorted_float64(b, order_b),
        _sorted_float64(x_a_flat, order_a),
        _sorted_float64(x_b_flat, order_b),
        metric=metric,
        p=p,
    )

    # map the sparse entries back to the original (unsorted) sample indices
    G = nx.coo_matrix(
        G_sorted,
        order_a[indices[:, 0]],
        order_b[indices[:, 1]],
        shape=(a.shape[0], b.shape[0]),
        type_as=x_a,
    )
    if dense:
        G = nx.todense(G)
    elif str(nx) == "jax":
        warnings.warn("JAX does not support sparse matrices, converting to dense")

    if not log:
        return G
    return G, {"cost": nx.from_numpy(cost, type_as=x_a)}
def emd2_1d(
    x_a,
    x_b,
    a=None,
    b=None,
    metric="sqeuclidean",
    p=1.0,
    dense=True,
    log=False,
    check_marginals=True,
):
    r"""Solves the Earth Movers distance problem between 1d measures and returns
    the loss


    .. math::
        \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0
    where :

    - d is the metric
    - x_a and x_b are the samples
    - a and b are the sample weights

    This implementation only supports metrics
    of the form :math:`d(x, y) = |x - y|^p`.

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only works with either of the strings
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`,  or `'euclidean'`.
    p: float, optional (default=1.0)
        The p-norm to apply for if metric='minkowski'
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Only used if log is set to True. Due to implementation details,
        this function runs faster when dense is set to False.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the transportation matrix.
        Otherwise returns only the loss.
    check_marginals: bool, optional (default=True)
        Forwarded to :func:`emd_1d`: if True, checks that the marginals mass
        are equal; if False, skips the check.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation
    log: dict
        If input log is True, a dictionary containing the Optimal transportation
        matrix for the given parameters


    Examples
    --------

    Simple example with obvious solution. The function emd2_1d accepts lists and
    performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.emd2_1d(x_a, x_b, a, b)
    0.5
    >>> ot.emd2_1d(x_a, x_b)
    0.5

    References
    ----------
    .. [1]  Peyré, G., & Cuturi, M. (2018). "Computational Optimal Transport".

    See Also
    --------
    ot.lp.emd2 : EMD for multidimensional distributions
    ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
        instead of the cost)
    """
    # If we do not return G (log==False), then we should not cast it to dense
    # (useless overhead)
    G, log_emd = emd_1d(
        x_a=x_a,
        x_b=x_b,
        a=a,
        b=b,
        metric=metric,
        p=p,
        dense=dense and log,
        log=True,
        check_marginals=check_marginals,
    )
    cost = log_emd["cost"]
    if log:
        return cost, {"G": G}
    return cost
def roll_cols(M, shifts):
    r"""
    Utils functions which allow to shift the order of each row of a 2d matrix

    Column ``j`` of output row ``i`` is column ``(j - s_i) % nc`` of ``M``,
    where ``s_i`` is the shift of row ``i``.

    Parameters
    ----------
    M : (nr, nc) ndarray
        Matrix to shift
    shifts: int or (nr, 1) ndarray
        Shift applied to every row (int) or per-row shifts (column vector, so
        that it broadcasts along the columns). It is also passed as
        ``type_as`` when building the column indices, so with non-NumPy
        backends it should presumably be a backend array.

    Returns
    -------
    Shifted array

    Examples
    --------
    >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
    >>> roll_cols(M, 2)
    array([[2, 3, 1],
           [5, 6, 4],
           [8, 9, 7]])
    >>> roll_cols(M, np.array([[1],[2],[1]]))
    array([[3, 1, 2],
           [5, 6, 4],
           [9, 7, 8]])

    References
    ----------
    https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
    """
    nx = get_backend(M)

    n_rows, n_cols = M.shape

    arange1 = nx.tile(
        nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1)
    )
    arange2 = (arange1 - shifts) % n_cols

    return nx.take_along_axis(M, arange2, 1)


def _rotate_target_on_circle(nx, theta, v_values, v_cdf):
    r"""Shift the target cdf by the cut ``theta`` and re-order it on the circle.

    Shared preprocessing of :func:`derivative_cost_on_circle` and
    :func:`ot_cost_on_circle`. It computes ``v_cdf - (theta - floor(theta))``.
    Entries that become negative are wrapped around: they get ``+1`` when the
    array also holds non-negative entries, and their locations are lifted by
    ``floor(theta) + 1``. The other locations are lifted by ``floor(theta)``.
    Each row is then rolled so that its smallest non-wrapped level comes
    first. The locations get one extra column ``v_values[:, 0] + 1`` that
    closes the circle.

    Parameters
    ----------
    nx: Backend
        backend of the inputs
    theta: array-like, shape (n_batch, m)
        Cuts on the circle (same shape as ``v_cdf``)
    v_values: array-like, shape (n_batch, m)
        locations of the target distribution (not modified)
    v_cdf: array-like, shape (n_batch, m)
        cdf of the target distribution

    Returns
    -------
    v_cdf_theta: array-like, shape (n_batch, m)
        shifted and rolled cdf
    v_values: array-like, shape (n_batch, m + 1)
        lifted, rolled and wrap-extended locations
    """
    # copy: the locations are updated in place below
    v_values = nx.copy(v_values)

    floor_theta = nx.floor(theta)
    v_cdf_theta = v_cdf - (theta - floor_theta)

    mask_p = v_cdf_theta >= 0
    mask_n = v_cdf_theta < 0

    v_values[mask_n] += floor_theta[mask_n] + 1
    v_values[mask_p] += floor_theta[mask_p]

    if nx.any(mask_n) and nx.any(mask_p):
        v_cdf_theta[mask_n] += 1

    # Put negative values at the end: roll each row so that its smallest
    # non-wrapped level comes first
    v_cdf_theta2 = nx.copy(v_cdf_theta)
    v_cdf_theta2[mask_n] = np.inf
    shift = nx.reshape(-nx.argmin(v_cdf_theta2, axis=-1), (-1, 1))

    v_cdf_theta = roll_cols(v_cdf_theta, shift)
    v_values = roll_cols(v_values, shift)
    v_values = nx.concatenate(
        [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1
    )
    return v_cdf_theta, v_values


def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
    r"""Computes the left and right derivative of the cost (Equation (6.3) and
    (6.4) of [44])

    Parameters
    ----------
    theta: array-like, shape (n_batch, m)
        Cuts on the circle (same shape as ``v_cdf``)
    u_values: array-like, shape (n_batch, n)
        locations of the first empirical distribution
    v_values: array-like, shape (n_batch, m)
        locations of the second empirical distribution
    u_cdf: array-like, shape (n_batch, n)
        cdf of the first empirical distribution
    v_cdf: array-like, shape (n_batch, m)
        cdf of the second empirical distribution
    p: float, optional = 2
        Power p used for computing the Wasserstein distance

    Returns
    -------
    dCp: array-like, shape (n_batch, 1)
        The batched right derivative
    dCm: array-like, shape (n_batch, 1)
        The batched left derivative

    References
    ---------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    """
    nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)

    n = u_values.shape[-1]

    v_cdf_theta, v_values = _rotate_target_on_circle(nx, theta, v_values, v_cdf)

    # quantiles of F_u evaluated in F_v^\theta
    u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
    u_valuesm = nx.concatenate(
        [u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1
    )

    if nx.__name__ == "torch":
        # this is to ensure the best performance for torch searchsorted
        # and avoid a warning related to non-contiguous arrays
        u_cdf = u_cdf.contiguous()
        u_cdfm = u_cdfm.contiguous()
        v_cdf_theta = v_cdf_theta.contiguous()

    # right derivative: left-sided search, indices clipped to the last atom
    u_index = nx.searchsorted(u_cdf, v_cdf_theta)
    u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)

    # left derivative: right-sided search on u extended by one wrap-around atom
    # (index n is valid in the extended arrays)
    u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
    u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)

    v_next = v_values[:, 1:]
    v_prev = v_values[:, :-1]

    dCp = nx.sum(
        nx.power(nx.abs(u_icdf_theta - v_next), p)
        - nx.power(nx.abs(u_icdf_theta - v_prev), p),
        axis=-1,
    )

    dCm = nx.sum(
        nx.power(nx.abs(u_icdfm_theta - v_next), p)
        - nx.power(nx.abs(u_icdfm_theta - v_prev), p),
        axis=-1,
    )

    return dCp.reshape(-1, 1), dCm.reshape(-1, 1)


def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
    r"""Computes the cost (Equation (6.2) of [44])

    Parameters
    ----------
    theta: array-like, shape (n_batch, m)
        Cuts on the circle (same shape as ``v_cdf``)
    u_values: array-like, shape (n_batch, n)
        locations of the first empirical distribution
    v_values: array-like, shape (n_batch, m)
        locations of the second empirical distribution
    u_cdf: array-like, shape (n_batch, n)
        cdf of the first empirical distribution
    v_cdf: array-like, shape (n_batch, m)
        cdf of the second empirical distribution
    p: float
        Power p used for computing the Wasserstein distance

    Returns
    -------
    ot_cost: array-like, shape (n_batch,)
        OT cost evaluated at theta

    References
    ---------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    """
    nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)

    n = u_values.shape[-1]
    m = v_values.shape[-1]

    # v_values comes back with m + 1 columns; column m closes the circle
    v_cdf_theta, v_values = _rotate_target_on_circle(nx, theta, v_values, v_cdf)

    # Compute abscissa: merged breakpoints of both cdfs and piece widths
    cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
    cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])

    delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]

    if nx.__name__ == "torch":
        # this is to ensure the best performance for torch searchsorted
        # and avoid a warning related to non-contiguous arrays
        u_cdf = u_cdf.contiguous()
        v_cdf_theta = v_cdf_theta.contiguous()
        cdf_axis = cdf_axis.contiguous()

    # Compute icdf
    u_index = nx.searchsorted(u_cdf, cdf_axis)
    u_icdf = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)

    v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
    v_icdf = nx.take_along_axis(v_values, nx.clip(v_index, 0, m), -1)

    if p == 1:
        ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
    else:
        ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)

    return ot_cost
def binary_search_circle(
    u_values,
    v_values,
    u_weights=None,
    v_weights=None,
    p=1,
    Lm=10,
    Lp=10,
    tm=-1,
    tp=1,
    eps=1e-6,
    require_sort=True,
    log=False,
):
    r"""Computes the Wasserstein distance on the circle using the Binary search
    algorithm proposed in [44]. Samples need to be in :math:`S^1\cong [0,1[`.
    If they are on :math:`\mathbb{R}`, takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first
    find the coordinates using e.g. the atan2 function.

    .. math::
        W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q)  - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q

    where:

    - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`

    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their
    coordinates with

    .. math::
        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}

    using e.g. ot.utils.get_coordinate_circle(x)

    The function runs on backend but tensorflow and jax are not supported.

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    v_values : ndarray, shape (m, ...)
        samples in the target domain (coordinates on [0,1[)
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    v_weights : ndarray, shape (m, ...), optional
        samples weights in the target domain
    p : float, optional (default=1)
        Power p used for computing the Wasserstein distance
    Lm : int, optional
        Lower bound dC (only ``max(Lm, Lp)`` is used, in the stopping
        threshold ``eps / max(Lm, Lp)`` on the width of the theta interval)
    Lp : int, optional
        Upper bound dC (see ``Lm``)
    tm: float, optional
        Lower bound theta (initial left end of the search interval)
    tp: float, optional
        Upper bound theta (initial right end of the search interval)
    eps: float, optional
        Stopping condition
    require_sort: bool, optional
        If True, sort the values.
    log: bool, optional
        If True, returns also the optimal theta

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation
    log: dict, optional
        log dictionary returned only if log==True in parameters
        (key ``"optimal_theta"``, one value per batch)

    Examples
    --------
    >>> u = np.array([[0.2,0.5,0.8]])%1
    >>> v = np.array([[0.4,0.5,0.7]])%1
    >>> binary_search_circle(u.T, v.T, p=1)
    array([0.1])

    References
    ----------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
    """
    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

    if u_weights is not None and v_weights is not None:
        nx = get_backend(u_values, v_values, u_weights, v_weights)
    else:
        nx = get_backend(u_values, v_values)

    n = u_values.shape[0]
    m = v_values.shape[0]

    # 1D inputs are treated as a single batch column
    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))
    if len(v_values.shape) == 1:
        v_values = nx.reshape(v_values, (m, 1))

    if u_values.shape[1] != v_values.shape[1]:
        raise ValueError(
            "u and v must have the same number of batches {} and {} respectively given".format(
                u_values.shape[1], v_values.shape[1]
            )
        )

    # map the samples onto [0, 1[
    u_values = u_values % 1
    v_values = v_values % 1

    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    if require_sort:
        u_sorter = nx.argsort(u_values, 0)
        u_values = nx.take_along_axis(u_values, u_sorter, 0)

        v_sorter = nx.argsort(v_values, 0)
        v_values = nx.take_along_axis(v_values, v_sorter, 0)

        u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
        v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

    # from here on, arrays are (n_batch, n_samples): one row per batch
    u_cdf = nx.cumsum(u_weights, 0).T
    v_cdf = nx.cumsum(v_weights, 0).T

    u_values = u_values.T
    v_values = v_values.T

    L = max(Lm, Lp)

    # search interval [tm, tp] and midpoint tc, one row per batch, tiled to m
    # columns (every column of a row holds the same theta)
    tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
    tm = nx.tile(tm, (1, m))
    tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
    tp = nx.tile(tp, (1, m))
    tc = (tm + tp) / 2

    # NOTE(review): created without type_as, so it uses the backend default
    # dtype/device; it is only read by the first loop test below
    done = nx.zeros((u_values.shape[0], m))

    # NOTE(review): iteration counter is incremented but never read; the loop
    # has no iteration cap
    cpt = 0
    while nx.any(1 - done):
        cpt += 1

        # theta is optimal where the right and left derivatives of the cost
        # have opposite signs (or one of them is 0)
        dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
        done = ((dCp * dCm) <= 0) * 1

        # rows whose interval is narrower than eps / L and not yet optimal
        mask = ((tp - tm) < eps / L) * (1 - done)

        if nx.any(mask):
            # can probably be improved by computing only relevant values
            dCptp, dCmtp = derivative_cost_on_circle(
                tp, u_values, v_values, u_cdf, v_cdf, p
            )
            dCptm, dCmtm = derivative_cost_on_circle(
                tm, u_values, v_values, u_cdf, v_cdf, p
            )
            Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(
                -1, 1
            )
            Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(
                -1, 1
            )

            # tc = intersection of the tangent at tm (slope = right derivative
            # dCptm) with the tangent at tp (slope = left derivative dCmtp);
            # skipped when the slopes are nearly equal
            mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
            tc[mask_end > 0] = (
                (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp)
            )[mask_end > 0]
            # NOTE(review): `done` is recomputed from the derivatives at the
            # top of the next iteration, and the bisection branch below does
            # not run while any row is masked; confirm termination for batches
            # where only some rows reach this branch
            done[nx.prod(mask, axis=-1) > 0] = 1
        elif nx.any(1 - done):
            # bisection step: move the bound on the side of the sign of dCp
            tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0]
            tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
            tc[((1 - mask) * (1 - done)) > 0] = (
                tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]
            ) / 2

    # theta is detached: the cost is evaluated at the (fixed) optimal cut
    w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p)

    if log:
        return w, {"optimal_theta": tc[:, 0]}
    return w
def wasserstein1_circle(
    u_values, v_values, u_weights=None, v_weights=None, require_sort=True
):
    r"""Computes the 1-Wasserstein distance on the circle using the level median [45]
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
    using e.g. the atan2 function.
    The function runs on backend but tensorflow and jax are not supported.

    .. math::
        W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t

    :math:`F_u - F_v` is piecewise constant between consecutive sorted atoms.
    The last piece runs from the largest atom around the circle to the smallest
    one, so the pieces cover the whole circle (total length 1).

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    v_values : ndarray, shape (m, ...)
        samples in the target domain (coordinates on [0,1[)
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    v_weights : ndarray, shape (m, ...), optional
        samples weights in the target domain
    require_sort: bool, optional
        If True, sort the values.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation

    Examples
    --------
    >>> u = np.array([[0.2,0.5,0.8]])%1
    >>> v = np.array([[0.4,0.5,0.7]])%1
    >>> wasserstein1_circle(u.T, v.T)
    array([0.1])
    >>> wasserstein1_circle(np.array([0.9]), np.array([0.1]))
    array([0.2])

    References
    ----------
    .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
    .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
    """
    if u_weights is not None and v_weights is not None:
        nx = get_backend(u_values, v_values, u_weights, v_weights)
    else:
        nx = get_backend(u_values, v_values)

    n = u_values.shape[0]
    m = v_values.shape[0]

    # 1D inputs are treated as a single batch column
    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))
    if len(v_values.shape) == 1:
        v_values = nx.reshape(v_values, (m, 1))

    if u_values.shape[1] != v_values.shape[1]:
        raise ValueError(
            "u and v must have the same number of batchs {} and {} respectively given".format(
                u_values.shape[1], v_values.shape[1]
            )
        )

    u_values = u_values % 1
    v_values = v_values % 1

    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    if require_sort:
        u_sorter = nx.argsort(u_values, 0)
        u_values = nx.take_along_axis(u_values, u_sorter, 0)

        v_sorter = nx.argsort(v_values, 0)
        v_values = nx.take_along_axis(v_values, v_sorter, 0)

        u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
        v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

    # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
    values_sorted, values_sorter = nx.sort2(
        nx.concatenate((u_values, v_values), 0), 0
    )

    # F_u - F_v right after each sorted atom
    cdf_diff = nx.cumsum(
        nx.take_along_axis(
            nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0
        ),
        0,
    )
    cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)

    # Close the circle with the smallest atom + 1: the last piece then covers
    # both [x_(N), 1) and [0, x_(1)). On [0, x_(1)) F_u - F_v is 0, which is
    # also its value after the last atom when both masses are equal. Padding
    # with 1 instead would silently drop the arc [0, x_(1)).
    values_sorted = nx.concatenate((values_sorted, values_sorted[:1, ...] + 1), 0)
    delta = values_sorted[1:, ...] - values_sorted[:-1, ...]

    # level median: smallest value of F_u - F_v whose sublevel set has
    # Lebesgue measure at least 1/2
    weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
    sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
    sum_weights[sum_weights < 0] = np.inf
    inds = nx.argmin(sum_weights, axis=0)

    levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)

    return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
def wasserstein_circle(
    u_values,
    v_values,
    u_weights=None,
    v_weights=None,
    p=1,
    Lm=10,
    Lp=10,
    tm=-1,
    tp=1,
    eps=1e-6,
    require_sort=True,
):
    r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
    the binary search algorithm proposed in [44] otherwise.
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find
    the coordinates using e.g. the atan2 function.

    General loss returned:

    .. math::
        OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q)  - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q

    For p=1, [45]

    .. math::
        W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t

    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their
    coordinates with

    .. math::
        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}

    using e.g. ot.utils.get_coordinate_circle(x)

    The function runs on backend but tensorflow and jax are not supported.

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    v_values : ndarray, shape (m, ...)
        samples in the target domain (coordinates on [0,1[)
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    v_weights : ndarray, shape (m, ...), optional
        samples weights in the target domain
    p : float, optional (default=1)
        Power p used for computing the Wasserstein distance
    Lm : int, optional
        Lower bound dC. For p>1.
    Lp : int, optional
        Upper bound dC. For p>1.
    tm: float, optional
        Lower bound theta. For p>1.
    tp: float, optional
        Upper bound theta. For p>1.
    eps: float, optional
        Stopping condition. For p>1.
    require_sort: bool, optional
        If True, sort the values.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation

    Examples
    --------
    >>> u = np.array([[0.2,0.5,0.8]])%1
    >>> v = np.array([[0.4,0.5,0.7]])%1
    >>> wasserstein_circle(u.T, v.T)
    array([0.1])

    References
    ----------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
    """
    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

    if p != 1:
        # general p: binary search over the cut theta [44]
        return binary_search_circle(
            u_values,
            v_values,
            u_weights,
            v_weights,
            p=p,
            Lm=Lm,
            Lp=Lp,
            tm=tm,
            tp=tp,
            eps=eps,
            require_sort=require_sort,
        )

    # p == 1: closed form through the level median [45]
    return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort)
def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
    r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first
    find the coordinates using e.g. the atan2 function.

    .. math::
        W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}

    where:

    - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`
      with the :math:`x_i` sorted in increasing order

    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their
    coordinates with

    .. math::
        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},

    using e.g. ot.utils.get_coordinate_circle(x)

    Parameters
    ----------
    u_values: ndarray, shape (n, ...)
        Samples
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain (uniform if None). They are
        reordered together with the samples when those are sorted.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation

    Examples
    --------
    >>> x0 = np.array([[0], [0.2], [0.4]])
    >>> semidiscrete_wasserstein2_unif_circle(x0)
    array([0.02111111])

    References
    ----------
    .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
    """
    if u_weights is not None:
        nx = get_backend(u_values, u_weights)
    else:
        nx = get_backend(u_values)

    n = u_values.shape[0]

    u_values = u_values % 1

    # 1D inputs are treated as a single batch column
    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))

    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)

    # The formula needs the atoms in increasing order: the weights must follow
    # the same permutation, otherwise non-uniform weights are paired with the
    # wrong atoms.
    u_sorter = nx.argsort(u_values, 0)
    u_values = nx.take_along_axis(u_values, u_sorter, 0)
    u_weights = nx.take_along_axis(u_weights, u_sorter, 0)

    # u_cdf[:-1] holds sum_{k<i} alpha_k (starts at 0)
    u_cdf = nx.cumsum(u_weights, 0)
    u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])

    cpt1 = nx.sum(u_weights * u_values**2, axis=0)
    u_mean = nx.sum(u_weights * u_values, axis=0)

    ns = 1 - u_weights - 2 * u_cdf[:-1]
    cpt2 = nx.sum(u_values * u_weights * ns, axis=0)

    return cpt1 - u_mean**2 + cpt2 + 1 / 12