Source code for ot.lp.solver_1d

# -*- coding: utf-8 -*-
"""
Exact solvers for the 1D Wasserstein distance
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Author: Nicolas Courty <ncourty@irisa.fr>
# Author: Clément Bonet <clement.bonet.mapp@polytechnique.edu>
#
# License: MIT License

import numpy as np
import warnings

from .emd_wrap import emd_1d_sorted
from ..backend import get_backend
from ..utils import list_to_array
from ._network_simplex import center_ot_dual


def quantile_function(qs, cws, xs, return_index=False):
    r"""Computes the quantile function of an empirical distribution

    Parameters
    ----------
    qs: array-like, shape (k, ...)
        Quantile levels at which the quantile function is evaluated
    cws: array-like, shape (n, ...)
        Cumulative weights of the 1D empirical distribution; if batched, the
        batch dimensions must match those of xs
    xs: array-like, shape (n, ...)
        Locations of the 1D empirical distribution, batched along the
        `xs.ndim - 1` trailing dimensions
    return_index: bool, optional
        If True, also return the indices of the selected atoms, default is False

    Returns
    -------
    q: array-like, shape (k, ...)
        The quantiles of the distribution
    idx: array-like, shape (k, ...)
        Indices of the atoms holding each quantile level, only returned if
        return_index is True
    """
    nx = get_backend(qs, cws)
    n = xs.shape[0]
    if nx.__name__ == "torch":
        # this is to ensure the best performance for torch searchsorted
        # and avoid a warning related to non-contiguous arrays
        cws = cws.movedim(0, -1).contiguous()
        qs = qs.movedim(0, -1).contiguous()
    else:
        cws = cws.T
        qs = qs.T

    idx = nx.clip(nx.searchsorted(cws, qs).T, 0, n - 1)

    if return_index:
        return nx.take_along_axis(xs, idx, axis=0), idx
    else:
        return nx.take_along_axis(xs, idx, axis=0)
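
# A minimal usage sketch with arbitrary values (not part of the module API):
# for a uniform distribution on three sorted atoms, each quantile level is
# mapped to the atom whose cumulative-weight bin contains it.
#
# >>> import numpy as np
# >>> xs = np.array([0.0, 1.0, 2.0])         # sorted atom locations
# >>> cws = np.cumsum(np.full(3, 1.0 / 3))   # cumulative weights [1/3, 2/3, 1]
# >>> qs = np.array([0.1, 0.5, 0.9])         # quantile levels to evaluate
# >>> quantile_function(qs, cws, xs)
# array([0., 1., 2.])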


def wasserstein_1d(
    u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True
):
    r"""
    Computes the 1 dimensional OT loss [15] between two (batched) empirical distributions

    .. math::
        OT_{loss} = \int_0^1 |\mathrm{cdf}_u^{-1}(q) - \mathrm{cdf}_v^{-1}(q)|^p \mathrm{d}q

    It is formally the p-Wasserstein distance raised to the power p.
    We do so in a vectorized way by first building the individual quantile
    functions then integrating them.

    This function should be preferred to `emd_1d` whenever the backend is
    different from numpy, and when gradients over either sample positions or
    weights are required.

    Parameters
    ----------
    u_values: array-like, shape (n, ...)
        locations of the first empirical distribution
    v_values: array-like, shape (m, ...)
        locations of the second empirical distribution
    u_weights: array-like, shape (n, ...), optional
        weights of the first empirical distribution, if None then uniform
        weights are used
    v_weights: array-like, shape (m, ...), optional
        weights of the second empirical distribution, if None then uniform
        weights are used
    p: int, optional
        order of the ground metric used, should be at least 1
        (see [2, Chap. 2]), default is 1
    require_sort: bool, optional
        sort the distributions atoms locations, if False we will consider
        they have been sorted prior to being passed to the function,
        default is True

    Returns
    -------
    cost: float/array-like, shape (...)
        the batched EMD

    References
    ----------
    .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.
    """

    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

    if u_weights is not None and v_weights is not None:
        nx = get_backend(u_values, v_values, u_weights, v_weights)
    else:
        nx = get_backend(u_values, v_values)

    n = u_values.shape[0]
    m = v_values.shape[0]

    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    if require_sort:
        u_sorter = nx.argsort(u_values, 0)
        u_values = nx.take_along_axis(u_values, u_sorter, 0)

        v_sorter = nx.argsort(v_values, 0)
        v_values = nx.take_along_axis(v_values, v_sorter, 0)

        u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
        v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

    u_cumweights = nx.cumsum(u_weights, 0)
    v_cumweights = nx.cumsum(v_weights, 0)

    qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0)
    u_quantiles = quantile_function(qs, u_cumweights, u_values)
    v_quantiles = quantile_function(qs, v_cumweights, v_values)
    qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
    delta = qs[1:, ...] - qs[:-1, ...]
    diff_quantiles = nx.abs(u_quantiles - v_quantiles)

    if p == 1:
        return nx.sum(delta * diff_quantiles, axis=0)
    return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
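
# A minimal usage sketch with arbitrary values (not part of the module API):
# two uniform distributions on two atoms each. Note that with p=2 the
# function returns the squared 2-Wasserstein distance, not its square root.
#
# >>> import numpy as np
# >>> x, y = np.array([0.0, 1.0]), np.array([1.0, 3.0])
# >>> float(wasserstein_1d(x, y, p=1))   # 0.5*|0-1| + 0.5*|1-3|
# 1.5
# >>> float(wasserstein_1d(x, y, p=2))   # 0.5*|0-1|^2 + 0.5*|1-3|^2
# 2.5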


def emd_1d(
    x_a,
    x_b,
    a=None,
    b=None,
    metric="sqeuclidean",
    p=1.0,
    dense=True,
    log=False,
    check_marginals=True,
):
    r"""Solves the Earth Mover's distance problem between 1d measures and
    returns the OT matrix

    .. math::
        \gamma = \mathop{\arg\min}_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

        s.t. \ \gamma \mathbf{1} = a, \quad \gamma^T \mathbf{1} = b, \quad \gamma \geq 0

    where :

    - d is the metric
    - :math:`x_a` and :math:`x_b` are the samples
    - a and b are the sample weights

    This implementation only supports metrics
    of the form :math:`d(x, y) = |x - y|^p`.

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : ndarray of float64, shape (ns,) or (ns, 1)
        Source dirac locations (on the real line)
    x_b : ndarray of float64, shape (nt,) or (nt, 1)
        Target dirac locations (on the real line)
    a : ndarray of float64, shape (ns,), optional
        Source histogram (default is uniform weight)
    b : ndarray of float64, shape (nt,), optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only works with either of the strings
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
    p: float, optional (default=1.0)
        The order of the norm to apply if metric='minkowski'
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Due to implementation details, this function runs faster when
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
        are used.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the cost.
        Otherwise returns only the optimal transportation matrix.
    check_marginals: bool, optional (default=True)
        If True, checks that the marginals mass are equal.
        If False, skips the check.

    Returns
    -------
    gamma: ndarray, shape (ns, nt)
        Optimal transportation matrix for the given parameters
    log: dict
        If input log is True, a dictionary containing the cost

    Examples
    --------

    Simple example with obvious solution. The function emd_1d accepts lists
    and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.emd_1d(x_a, x_b, a, b)
    array([[0. , 0.5],
           [0.5, 0. ]])
    >>> ot.emd_1d(x_a, x_b)
    array([[0. , 0.5],
           [0.5, 0. ]])

    References
    ----------
    .. [1] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.

    See Also
    --------
    ot.lp.emd : EMD for multidimensional distributions
    ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
        transportation matrix)
    """
    x_a, x_b = list_to_array(x_a, x_b)
    nx = get_backend(x_a, x_b)
    if a is not None:
        a = list_to_array(a, nx=nx)
    if b is not None:
        b = list_to_array(b, nx=nx)

    assert x_a.ndim == 1 or (
        x_a.ndim == 2 and x_a.shape[1] == 1
    ), "emd_1d should only be used with monodimensional data"
    assert x_b.ndim == 1 or (
        x_b.ndim == 2 and x_b.shape[1] == 1
    ), "emd_1d should only be used with monodimensional data"
    if metric not in ["sqeuclidean", "minkowski", "cityblock", "euclidean"]:
        raise ValueError(
            "Solver for EMD in 1d only supports metrics "
            + "from the following list: "
            + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
        )

    # if empty array given then use uniform distributions
    if a is None or a.ndim == 0 or len(a) == 0:
        a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
    if b is None or b.ndim == 0 or len(b) == 0:
        b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]

    # ensure that the marginals have the same mass
    if check_marginals:
        np.testing.assert_almost_equal(
            nx.to_numpy(nx.sum(a, axis=0)),
            nx.to_numpy(nx.sum(b, axis=0)),
            err_msg="a and b vector must have the same sum",
            decimal=6,
        )
    b = b * nx.sum(a) / nx.sum(b)

    x_a_1d = nx.reshape(x_a, (-1,))
    x_b_1d = nx.reshape(x_b, (-1,))
    perm_a = nx.argsort(x_a_1d)
    perm_b = nx.argsort(x_b_1d)

    G_sorted, indices, cost = emd_1d_sorted(
        nx.to_numpy(a[perm_a]).astype(np.float64),
        nx.to_numpy(b[perm_b]).astype(np.float64),
        nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
        nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
        metric=metric,
        p=p,
    )

    G = nx.coo_matrix(
        G_sorted,
        perm_a[indices[:, 0]],
        perm_b[indices[:, 1]],
        shape=(a.shape[0], b.shape[0]),
        type_as=x_a,
    )
    if dense:
        G = nx.todense(G)
    elif str(nx) == "jax":
        warnings.warn("JAX does not support sparse matrices, converting to dense")
    if log:
        log = {"cost": nx.from_numpy(cost, type_as=x_a)}
        return G, log
    return G
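
# A minimal sketch of the log option with arbitrary values (not part of the
# module API): with metric="cityblock", each half-unit of mass travels a
# distance of 1, so the reported cost is 1.
#
# >>> G, log = emd_1d([0.0, 1.0], [1.0, 2.0], metric="cityblock", log=True)
# >>> float(log["cost"])
# 1.0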


def emd2_1d(
    x_a, x_b, a=None, b=None, metric="sqeuclidean", p=1.0, dense=True, log=False
):
    r"""Solves the Earth Mover's distance problem between 1d measures and
    returns the loss

    .. math::
        \gamma = \mathop{\arg\min}_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

        s.t. \ \gamma \mathbf{1} = a, \quad \gamma^T \mathbf{1} = b, \quad \gamma \geq 0

    where :

    - d is the metric
    - :math:`x_a` and :math:`x_b` are the samples
    - a and b are the sample weights

    This implementation only supports metrics
    of the form :math:`d(x, y) = |x - y|^p`.

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : ndarray of float64, shape (ns,) or (ns, 1)
        Source dirac locations (on the real line)
    x_b : ndarray of float64, shape (nt,) or (nt, 1)
        Target dirac locations (on the real line)
    a : ndarray of float64, shape (ns,), optional
        Source histogram (default is uniform weight)
    b : ndarray of float64, shape (nt,), optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only works with either of the strings
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
    p: float, optional (default=1.0)
        The order of the norm to apply if metric='minkowski'
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Only used if log is set to True. Due to implementation details,
        this function runs faster when dense is set to False.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the transportation matrix.
        Otherwise returns only the loss.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation
    log: dict
        If input log is True, a dictionary containing the optimal
        transportation matrix for the given parameters

    Examples
    --------

    Simple example with obvious solution. The function emd2_1d accepts lists
    and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.emd2_1d(x_a, x_b, a, b)
    0.5
    >>> ot.emd2_1d(x_a, x_b)
    0.5

    References
    ----------
    .. [1] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.

    See Also
    --------
    ot.lp.emd2 : EMD for multidimensional distributions
    ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
        instead of the cost)
    """
    # If we do not return G (log==False), then we should not cast it to dense
    # (useless overhead)
    G, log_emd = emd_1d(
        x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, dense=dense and log, log=True
    )
    cost = log_emd["cost"]
    if log:
        log_emd = {"G": G}
        return cost, log_emd
    return cost
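
# A minimal consistency sketch with arbitrary values (not part of the module
# API): for metric="sqeuclidean", emd2_1d should match wasserstein_1d with
# p=2, since both return the squared 2-Wasserstein distance.
#
# >>> import numpy as np
# >>> x, y = np.array([0.0, 1.0]), np.array([1.0, 3.0])
# >>> float(emd2_1d(x, y, metric="sqeuclidean"))
# 2.5
# >>> float(wasserstein_1d(x, y, p=2))
# 2.5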


def emd_1d_dual_backprop(
    u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True
):
    r"""
    Computes the 1 dimensional OT loss between two (batched) empirical distributions

    .. math::
        OT_{loss} = \int_0^1 |\mathrm{cdf}_u^{-1}(q) - \mathrm{cdf}_v^{-1}(q)|^p \mathrm{d}q

    and returns the dual potentials and the loss, i.e. such that

    .. math::
        OT_{loss}(u,v) = \int f(x)\mathrm{d}u(x) + \int g(y)\mathrm{d}v(y).

    .. warning:: This function only works in pytorch or jax, as it
        backpropagates through the `wasserstein_1d` function.

    Parameters
    ----------
    u_values: array-like, shape (n, ...)
        locations of the first empirical distribution
    v_values: array-like, shape (m, ...)
        locations of the second empirical distribution
    u_weights: array-like, shape (n, ...), optional
        weights of the first empirical distribution, if None then uniform
        weights are used
    v_weights: array-like, shape (m, ...), optional
        weights of the second empirical distribution, if None then uniform
        weights are used
    p: int, optional
        order of the ground metric used, should be at least 1, default is 1
    require_sort: bool, optional
        sort the distributions atoms locations, if False we will consider
        they have been sorted prior to being passed to the function,
        default is True

    Returns
    -------
    f: array-like, shape (n, ...)
        First dual potential
    g: array-like, shape (m, ...)
        Second dual potential
    loss: float/array-like, shape (...)
        the batched EMD
    """
    nx = get_backend(u_values, v_values, u_weights, v_weights)
    assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax"

    n = u_values.shape[0]
    m = v_values.shape[0]

    # Init weights or broadcast if necessary
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)

    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    if nx.__name__ == "torch":
        u_weights_diff = nx.copy(u_weights)
        v_weights_diff = nx.copy(v_weights)

        u_weights_diff.requires_grad_(True)
        v_weights_diff.requires_grad_(True)

        cost_output = wasserstein_1d(
            u_values,
            v_values,
            u_weights_diff,
            v_weights_diff,
            p=p,
            require_sort=require_sort,
        )

        loss = cost_output.sum()
        loss.backward()

        # the gradients w.r.t. the weights are the (uncentered) dual potentials
        f, g = center_ot_dual(
            u_weights_diff.grad.detach(),
            v_weights_diff.grad.detach(),
            u_weights,
            v_weights,
        )

        # the returned value can no longer be backpropagated through
        return f, g, cost_output.detach()
    elif nx.__name__ == "jax":
        import jax

        def ot_1d(a, b):
            return wasserstein_1d(
                u_values, v_values, a, b, p=p, require_sort=require_sort
            ).sum()

        # in jax, the dual potentials are obtained directly as gradients
        f, g = jax.grad(ot_1d, argnums=[0, 1])(u_weights, v_weights)
        cost_output = wasserstein_1d(
            u_values, v_values, u_weights, v_weights, p=p, require_sort=require_sort
        )
        f, g = center_ot_dual(f, g, u_weights, v_weights)
        return f, g, cost_output
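
# A minimal sketch of the duality relation stated in the docstring above,
# assuming torch is available (the point clouds are arbitrary): the returned
# potentials should satisfy <f, a> + <g, b> = OT_loss up to numerical
# precision.
#
# >>> import torch
# >>> x, y = torch.tensor([0.0, 1.0]), torch.tensor([1.0, 2.0])
# >>> a = b = torch.tensor([0.5, 0.5])
# >>> f, g, loss = emd_1d_dual_backprop(x, y, a, b)
# >>> bool(torch.allclose(f @ a + g @ b, loss))
# True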