# Source code for ot.lp.solver_1d

# -*- coding: utf-8 -*-
"""
Exact solvers for the 1D Wasserstein distance
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Author: Nicolas Courty <ncourty@irisa.fr>
#

import numpy as np
import warnings

from .emd_wrap import emd_1d_sorted
from ..backend import get_backend
from ..utils import list_to_array

def quantile_function(qs, cws, xs):
r""" Computes the quantile function of an empirical distribution

Parameters
----------
qs: array-like, shape (n,)
Quantiles at which the quantile function is evaluated
cws: array-like, shape (m, ...)
cumulative weights of the 1D empirical distribution, if batched, must be similar to xs
xs: array-like, shape (n, ...)
locations of the 1D empirical distribution, batched against the xs.ndim - 1 first dimensions

Returns
-------
q: array-like, shape (..., n)
The quantiles of the distribution
"""
nx = get_backend(qs, cws)
n = xs.shape[0]
if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
# and avoid a warning related to non-contiguous arrays
cws = cws.T.contiguous()
qs = qs.T.contiguous()
else:
cws = cws.T
qs = qs.T
idx = nx.searchsorted(cws, qs).T
return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0)
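
# Usage sketch (illustrative, not part of the original source): evaluating the
# empirical quantile function with the numpy backend; the locations and weights
# below are made-up example values.
#
# >>> import numpy as np
# >>> xs = np.array([[0.0], [1.0], [3.0]])   # sorted locations, shape (n, 1)
# >>> cws = np.array([[0.2], [0.7], [1.0]])  # cumulative weights, shape (n, 1)
# >>> qs = np.array([[0.1], [0.5], [0.9]])   # quantile levels, shape (k, 1)
# >>> quantile_function(qs, cws, xs)
# array([[0.],
#        [1.],
#        [3.]])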

def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
r"""
Computes the 1-dimensional OT loss [15] between two (batched) empirical
distributions

.. math::
    OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p \, \mathrm{d}q

It is formally the p-Wasserstein distance raised to the power p.
We do so in a vectorized way by first building the individual quantile functions then integrating them.

This function should be preferred to emd_1d whenever the backend is
different to numpy, and when gradients over
either sample positions or weights are required.

Parameters
----------
u_values: array-like, shape (n, ...)
locations of the first empirical distribution
v_values: array-like, shape (m, ...)
locations of the second empirical distribution
u_weights: array-like, shape (n, ...), optional
weights of the first empirical distribution, if None then uniform weights are used
v_weights: array-like, shape (m, ...), optional
weights of the second empirical distribution, if None then uniform weights are used
p: int, optional
order of the ground metric used; should be at least 1 (see [15], Chapter 2). Default is 1.
require_sort: bool, optional
sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to
the function, default is True

Returns
-------
cost: float/array-like, shape (...)
the batched EMD

References
----------
.. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.

"""

assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

if u_weights is not None and v_weights is not None:
nx = get_backend(u_values, v_values, u_weights, v_weights)
else:
nx = get_backend(u_values, v_values)

n = u_values.shape[0]
m = v_values.shape[0]

if u_weights is None:
u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
elif u_weights.ndim != u_values.ndim:
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
if v_weights is None:
v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
elif v_weights.ndim != v_values.ndim:
v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

if require_sort:
u_sorter = nx.argsort(u_values, 0)
u_values = nx.take_along_axis(u_values, u_sorter, 0)

v_sorter = nx.argsort(v_values, 0)
v_values = nx.take_along_axis(v_values, v_sorter, 0)

u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

u_cumweights = nx.cumsum(u_weights, 0)
v_cumweights = nx.cumsum(v_weights, 0)

qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0)
u_quantiles = quantile_function(qs, u_cumweights, u_values)
v_quantiles = quantile_function(qs, v_cumweights, v_values)
qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)])
delta = qs[1:, ...] - qs[:-1, ...]
diff_quantiles = nx.abs(u_quantiles - v_quantiles)

if p == 1:
return nx.sum(delta * diff_quantiles, axis=0)
return nx.sum(delta * nx.power(diff_quantiles, p), axis=0)
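
# Usage sketch (illustrative, not part of the original source): two empirical
# distributions whose atoms differ by a constant shift of 1, so both the
# 1-Wasserstein distance and the squared 2-Wasserstein distance equal 1.
#
# >>> import numpy as np
# >>> import ot
# >>> x = np.array([[0.0], [1.0], [2.0]])   # shape (n, 1): one batch column
# >>> y = np.array([[1.0], [2.0], [3.0]])
# >>> ot.wasserstein_1d(x, y, p=1)
# array([1.])
# >>> ot.wasserstein_1d(x, y, p=2)   # returns the distance raised to the power p
# array([1.])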

def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log=False, check_marginals=True):
r"""Solves the Earth Movers distance problem between 1d measures and returns
the OT matrix

.. math::
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

s.t. \gamma 1 = a,
\gamma^T 1= b,
\gamma\geq 0

where :

- d is the metric
- x_a and x_b are the samples
- a and b are the sample weights

When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.

Uses the algorithm detailed in [1]_

Parameters
----------
x_a : (ns,) or (ns, 1) ndarray, float64
Source dirac locations (on the real line)
x_b : (nt,) or (nt, 1) ndarray, float64
Target dirac locations (on the real line)
a : (ns,) ndarray, float64, optional
Source histogram (default is uniform weight)
b : (nt,) ndarray, float64, optional
Target histogram (default is uniform weight)
metric: str, optional (default='sqeuclidean')
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
Due to implementation details, this function runs faster when
'sqeuclidean', 'cityblock',  or 'euclidean' metrics are used.
p: float, optional (default=1.0)
The p-norm to apply for if metric='minkowski'
dense: boolean, optional (default=True)
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's coo_matrix
format. Due to implementation details, this function runs faster when
'sqeuclidean', 'minkowski', 'cityblock',  or 'euclidean' metrics
are used.
log: boolean, optional (default=False)
If True, returns a dictionary containing the cost.
Otherwise returns only the optimal transportation matrix.
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.

Returns
-------
gamma: (ns, nt) ndarray
Optimal transportation matrix for the given parameters
log: dict
If input log is True, a dictionary containing the cost

Examples
--------

Simple example with obvious solution. The function emd_1d accepts lists and
performs automatic conversion to numpy arrays

>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> x_a = [2., 0.]
>>> x_b = [0., 3.]
>>> ot.emd_1d(x_a, x_b, a, b)
array([[0. , 0.5],
[0.5, 0. ]])
>>> ot.emd_1d(x_a, x_b)
array([[0. , 0.5],
[0.5, 0. ]])

References
----------

.. [1] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.

See Also
--------
ot.lp.emd : EMD for multidimensional distributions
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
transportation matrix)
"""
x_a, x_b = list_to_array(x_a, x_b)
nx = get_backend(x_a, x_b)
if a is not None:
a = list_to_array(a, nx=nx)
if b is not None:
b = list_to_array(b, nx=nx)

assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
"emd_1d should only be used with monodimensional data"
assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
"emd_1d should only be used with monodimensional data"

# if empty array given then use uniform distributions
if a is None or a.ndim == 0 or len(a) == 0:
a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0]
if b is None or b.ndim == 0 or len(b) == 0:
b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]

# ensure that same mass
if check_marginals:
np.testing.assert_almost_equal(
nx.to_numpy(nx.sum(a, axis=0)),
nx.to_numpy(nx.sum(b, axis=0)),
err_msg='a and b vector must have the same sum',
decimal=6
)
b = b * nx.sum(a) / nx.sum(b)

x_a_1d = nx.reshape(x_a, (-1,))
x_b_1d = nx.reshape(x_b, (-1,))
perm_a = nx.argsort(x_a_1d)
perm_b = nx.argsort(x_b_1d)

G_sorted, indices, cost = emd_1d_sorted(
nx.to_numpy(a[perm_a]).astype(np.float64),
nx.to_numpy(b[perm_b]).astype(np.float64),
nx.to_numpy(x_a_1d[perm_a]).astype(np.float64),
nx.to_numpy(x_b_1d[perm_b]).astype(np.float64),
metric=metric, p=p
)

G = nx.coo_matrix(
G_sorted,
perm_a[indices[:, 0]],
perm_b[indices[:, 1]],
shape=(a.shape[0], b.shape[0]),
type_as=x_a
)
if dense:
G = nx.todense(G)
elif str(nx) == "jax":
warnings.warn("JAX does not support sparse matrices, converting to dense")
if log:
log = {'cost': nx.from_numpy(cost, type_as=x_a)}
return G, log
return G
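
# Usage sketch (illustrative, not part of the original source): requesting the
# log dictionary to recover the cost along with the transport plan.
#
# >>> import numpy as np
# >>> import ot
# >>> x_a = np.array([2., 0.])
# >>> x_b = np.array([0., 3.])
# >>> G, log = ot.emd_1d(x_a, x_b, metric='euclidean', log=True)
# >>> G
# array([[0. , 0.5],
#        [0.5, 0. ]])
# >>> float(log['cost'])
# 0.5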

def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log=False):
r"""Solves the Earth Movers distance problem between 1d measures and returns
the loss

.. math::
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

s.t. \gamma 1 = a,
\gamma^T 1= b,
\gamma\geq 0

where :

- d is the metric
- x_a and x_b are the samples
- a and b are the sample weights

When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.

Uses the algorithm detailed in [1]_

Parameters
----------
x_a : (ns,) or (ns, 1) ndarray, float64
Source dirac locations (on the real line)
x_b : (nt,) or (nt, 1) ndarray, float64
Target dirac locations (on the real line)
a : (ns,) ndarray, float64, optional
Source histogram (default is uniform weight)
b : (nt,) ndarray, float64, optional
Target histogram (default is uniform weight)
metric: str, optional (default='sqeuclidean')
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
Due to implementation details, this function runs faster when
'sqeuclidean', 'minkowski', 'cityblock',  or 'euclidean' metrics
are used.
p: float, optional (default=1.0)
The p-norm to apply for if metric='minkowski'
dense: boolean, optional (default=True)
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's coo_matrix
format. Only used if log is set to True. Due to implementation details,
this function runs faster when dense is set to False.
log: boolean, optional (default=False)
If True, returns a dictionary containing the transportation matrix.
Otherwise returns only the loss.

Returns
-------
loss: float
Cost associated to the optimal transportation
log: dict
If input log is True, a dictionary containing the Optimal transportation
matrix for the given parameters

Examples
--------

Simple example with obvious solution. The function emd2_1d accepts lists and
performs automatic conversion to numpy arrays

>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> x_a = [2., 0.]
>>> x_b = [0., 3.]
>>> ot.emd2_1d(x_a, x_b, a, b)
0.5
>>> ot.emd2_1d(x_a, x_b)
0.5

References
----------

.. [1] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport.

See Also
--------
ot.lp.emd2 : EMD for multidimensional distributions
ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix)
"""
# If we do not return G (log==False), then we do not need to cast it to dense
G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
dense=dense and log, log=True)
cost = log_emd['cost']
if log:
log_emd = {'G': G}
return cost, log_emd
return cost
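
# Usage sketch (illustrative, not part of the original source): with the
# 'minkowski' metric the ground cost is |x - y|^p, here with p=2; the optimal
# plan matches 0 with 0 and 2 with 3, so the loss is 0.5 * |2 - 3|^2 = 0.5.
#
# >>> import ot
# >>> x_a = [2., 0.]
# >>> x_b = [0., 3.]
# >>> ot.emd2_1d(x_a, x_b, metric='minkowski', p=2.)
# 0.5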

def roll_cols(M, shifts):
r"""
Utility function that shifts the entries of each row of a 2D matrix by a given offset (possibly different for each row)

Parameters
----------
M : (nr, nc) ndarray
Matrix to shift
shifts: int or (nr, 1) ndarray
number of positions by which each row is shifted (a scalar applies the same shift to every row)

Returns
-------
Shifted array

Examples
--------
>>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
>>> roll_cols(M, 2)
array([[2, 3, 1],
[5, 6, 4],
[8, 9, 7]])
>>> roll_cols(M, np.array([[1],[2],[1]]))
array([[3, 1, 2],
[5, 6, 4],
[9, 7, 8]])

References
----------
https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
"""
nx = get_backend(M)

n_rows, n_cols = M.shape

arange1 = nx.tile(nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1))
arange2 = (arange1 - shifts) % n_cols

return nx.take_along_axis(M, arange2, 1)

def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1])

Parameters
----------
theta: array-like, shape (n_batch, n)
Cuts on the circle
u_values: array-like, shape (n_batch, n)
locations of the first empirical distribution
v_values: array-like, shape (n_batch, n)
locations of the second empirical distribution
u_cdf: array-like, shape (n_batch, n)
cdf of the first empirical distribution
v_cdf: array-like, shape (n_batch, n)
cdf of the second empirical distribution
p: float, optional (default=2)
Power p used for computing the Wasserstein distance

Returns
-------
dCp: array-like, shape (n_batch, 1)
The batched right derivative
dCm: array-like, shape (n_batch, 1)
The batched left derivative

References
----------
.. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
"""
nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)

v_values = nx.copy(v_values)

n = u_values.shape[-1]
m_batch, m = v_values.shape

v_cdf_theta = v_cdf - (theta - nx.floor(theta))

# Put negative values at the end
v_cdf_theta2 = nx.copy(v_cdf_theta)
v_cdf_theta2[v_cdf_theta2 < 0] = np.inf
shift = (-nx.argmin(v_cdf_theta2, axis=-1))

v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)

if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
# and avoid a warning related to non-contiguous arrays
u_cdf = u_cdf.contiguous()
v_cdf_theta = v_cdf_theta.contiguous()

# quantiles of F_u evaluated in F_v^\theta
u_index = nx.searchsorted(u_cdf, v_cdf_theta)
u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)

# Deal with 1
u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1)

if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
# and avoid a warning related to non-contiguous arrays
u_cdfm = u_cdfm.contiguous()
v_cdf_theta = v_cdf_theta.contiguous()

u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)

dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p)
- nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1)

dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p)
- nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1)

return dCp.reshape(-1, 1), dCm.reshape(-1, 1)

def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
r""" Computes the the cost (Equation (6.2) of [1])

Parameters
----------
theta: array-like, shape (n_batch, n)
Cuts on the circle
u_values: array-like, shape (n_batch, n)
locations of the first empirical distribution
v_values: array-like, shape (n_batch, n)
locations of the second empirical distribution
u_cdf: array-like, shape (n_batch, n)
cdf of the first empirical distribution
v_cdf: array-like, shape (n_batch, n)
cdf of the second empirical distribution
p: float
Power p used for computing the Wasserstein distance

Returns
-------
ot_cost: array-like, shape (n_batch,)
OT cost evaluated at theta

References
----------
.. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
"""
nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)

v_values = nx.copy(v_values)

m_batch, m = v_values.shape
n_batch, n = u_values.shape

v_cdf_theta = v_cdf - (theta - nx.floor(theta))

# Put negative values at the end
v_cdf_theta2 = nx.copy(v_cdf_theta)
v_cdf_theta2[v_cdf_theta2 < 0] = np.inf
shift = (-nx.argmin(v_cdf_theta2, axis=-1))

v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)

# Compute the abscissa of the steps of both cdfs and the gaps between them
cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])
delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]

if nx.__name__ == 'torch':
# this is to ensure the best performance for torch searchsorted
# and avoid a warning related to non-contiguous arrays
u_cdf = u_cdf.contiguous()
v_cdf_theta = v_cdf_theta.contiguous()
cdf_axis = cdf_axis.contiguous()

# Compute icdf
u_index = nx.searchsorted(u_cdf, cdf_axis)
u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1)

v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1)
v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1)

if p == 1:
ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
else:
ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)

return ot_cost

def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True,
log=False):
r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
takes the value modulo 1.
If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
using e.g. the atan2 function.

.. math::
W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q)  - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q

where:

- :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`

For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with

.. math::
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}

using e.g. ot.utils.get_coordinate_circle(x)

The function works with all POT backends except tensorflow and jax, which are not supported.

Parameters
----------
u_values : ndarray, shape (n, ...)
samples in the source domain (coordinates on [0,1[)
v_values : ndarray, shape (n, ...)
samples in the target domain (coordinates on [0,1[)
u_weights : ndarray, shape (n, ...), optional
samples weights in the source domain
v_weights : ndarray, shape (n, ...), optional
samples weights in the target domain
p : float, optional (default=1)
Power p used for computing the Wasserstein distance
Lm : int, optional
Lower bound on the cost derivative dC
Lp : int, optional
Upper bound on the cost derivative dC
tm: float, optional
Lower bound theta
tp: float, optional
Upper bound theta
eps: float, optional
Stopping condition
require_sort: bool, optional
If True, sort the values.
log: bool, optional
If True, returns also the optimal theta

Returns
-------
loss: float
Cost associated to the optimal transportation
log: dict, optional
log dictionary returned only if log==True in parameters

Examples
--------
>>> u = np.array([[0.2,0.5,0.8]])%1
>>> v = np.array([[0.4,0.5,0.7]])%1
>>> binary_search_circle(u.T, v.T, p=1)
array([0.1])

References
----------
.. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
.. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
"""
assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

if u_weights is not None and v_weights is not None:
nx = get_backend(u_values, v_values, u_weights, v_weights)
else:
nx = get_backend(u_values, v_values)

n = u_values.shape[0]
m = v_values.shape[0]

if len(u_values.shape) == 1:
u_values = nx.reshape(u_values, (n, 1))
if len(v_values.shape) == 1:
v_values = nx.reshape(v_values, (m, 1))

if u_values.shape[1] != v_values.shape[1]:
raise ValueError(
"u and v must have the same number of batches: got {} and {} respectively".format(u_values.shape[1],
v_values.shape[1]))

u_values = u_values % 1
v_values = v_values % 1

if u_weights is None:
u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
elif u_weights.ndim != u_values.ndim:
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
if v_weights is None:
v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
elif v_weights.ndim != v_values.ndim:
v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

if require_sort:
u_sorter = nx.argsort(u_values, 0)
u_values = nx.take_along_axis(u_values, u_sorter, 0)

v_sorter = nx.argsort(v_values, 0)
v_values = nx.take_along_axis(v_values, v_sorter, 0)

u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

u_cdf = nx.cumsum(u_weights, 0).T
v_cdf = nx.cumsum(v_weights, 0).T

u_values = u_values.T
v_values = v_values.T

L = max(Lm, Lp)

tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
tm = nx.tile(tm, (1, m))
tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1))
tp = nx.tile(tp, (1, m))
tc = (tm + tp) / 2

done = nx.zeros((u_values.shape[0], m))

cpt = 0
while nx.any(1 - done):
cpt += 1

dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)
done = ((dCp * dCm) <= 0) * 1

mask = ((tp - tm) < eps / L) * (1 - done)

if nx.any(mask):
    # can probably be improved by computing only relevant values
    dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p)
    dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p)
    Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)
    Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1)

    mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001)
    tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0]
    done[nx.prod(mask, axis=-1) > 0] = 1
elif nx.any(1 - done):
    tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0]
    tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0]
    tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2

w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p)

if log:
return w, {"optimal_theta": tc[:, 0]}
return w
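
# Usage sketch (illustrative, not part of the original source): squared
# 2-Wasserstein distance on the circle, with the optimal cut returned in the
# log dictionary (only shapes are checked here, not specific values).
#
# >>> import numpy as np
# >>> u = np.array([[0.2, 0.5, 0.8]]) % 1
# >>> v = np.array([[0.4, 0.5, 0.7]]) % 1
# >>> w, log = binary_search_circle(u.T, v.T, p=2, log=True)
# >>> w.shape, log["optimal_theta"].shape
# ((1,), (1,))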

def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True):
r"""Computes the 1-Wasserstein distance on the circle using the level median [45].
Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
takes the value modulo 1.
If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
using e.g. the atan2 function.
The function works with all POT backends except tensorflow and jax, which are not supported.

.. math::
W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t

Parameters
----------
u_values : ndarray, shape (n, ...)
samples in the source domain (coordinates on [0,1[)
v_values : ndarray, shape (n, ...)
samples in the target domain (coordinates on [0,1[)
u_weights : ndarray, shape (n, ...), optional
samples weights in the source domain
v_weights : ndarray, shape (n, ...), optional
samples weights in the target domain
require_sort: bool, optional
If True, sort the values.

Returns
-------
loss: float
Cost associated to the optimal transportation

Examples
--------
>>> u = np.array([[0.2,0.5,0.8]])%1
>>> v = np.array([[0.4,0.5,0.7]])%1
>>> wasserstein1_circle(u.T, v.T)
array([0.1])

References
----------
.. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
.. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
"""

if u_weights is not None and v_weights is not None:
nx = get_backend(u_values, v_values, u_weights, v_weights)
else:
nx = get_backend(u_values, v_values)

n = u_values.shape[0]
m = v_values.shape[0]

if len(u_values.shape) == 1:
u_values = nx.reshape(u_values, (n, 1))
if len(v_values.shape) == 1:
v_values = nx.reshape(v_values, (m, 1))

if u_values.shape[1] != v_values.shape[1]:
raise ValueError(
"u and v must have the same number of batches: got {} and {} respectively".format(u_values.shape[1],
v_values.shape[1]))

u_values = u_values % 1
v_values = v_values % 1

if u_weights is None:
u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
elif u_weights.ndim != u_values.ndim:
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
if v_weights is None:
v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values)
elif v_weights.ndim != v_values.ndim:
v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

if require_sort:
u_sorter = nx.argsort(u_values, 0)
u_values = nx.take_along_axis(u_values, u_sorter, 0)

v_sorter = nx.argsort(v_values, 0)
v_values = nx.take_along_axis(v_values, v_sorter, 0)

u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

# Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0)

cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0)
cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)

values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
delta = values_sorted[1:, ...] - values_sorted[:-1, ...]
weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)

sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
sum_weights[sum_weights < 0] = np.inf
inds = nx.argmin(sum_weights, axis=0)

levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)

return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)

def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True):
r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
the binary search algorithm proposed in [44] otherwise.
Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
takes the value modulo 1.
If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
using e.g. the atan2 function.

General loss returned:

.. math::
OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q)  - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q

For p=1, [45]

.. math::
W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t

For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with

.. math::
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}

using e.g. ot.utils.get_coordinate_circle(x)

The function works with all POT backends except tensorflow and jax, which are not supported.

Parameters
----------
u_values : ndarray, shape (n, ...)
samples in the source domain (coordinates on [0,1[)
v_values : ndarray, shape (n, ...)
samples in the target domain (coordinates on [0,1[)
u_weights : ndarray, shape (n, ...), optional
samples weights in the source domain
v_weights : ndarray, shape (n, ...), optional
samples weights in the target domain
p : float, optional (default=1)
Power p used for computing the Wasserstein distance
Lm : int, optional
Lower bound on the cost derivative dC. Only used if p > 1.
Lp : int, optional
Upper bound on the cost derivative dC. Only used if p > 1.
tm: float, optional
Lower bound on theta. Only used if p > 1.
tp: float, optional
Upper bound on theta. Only used if p > 1.
eps: float, optional
Stopping condition. Only used if p > 1.
require_sort: bool, optional
If True, sort the values.

Returns
-------
loss: float
Cost associated to the optimal transportation

Examples
--------
>>> u = np.array([[0.2,0.5,0.8]])%1
>>> v = np.array([[0.4,0.5,0.7]])%1
>>> wasserstein_circle(u.T, v.T)
array([0.1])

References
----------
.. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
.. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
"""
assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

if p == 1:
return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort)

return binary_search_circle(u_values, v_values, u_weights, v_weights,
p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps,
require_sort=require_sort)
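
# Usage sketch (illustrative, not part of the original source): samples given on
# the real line are wrapped modulo 1, so adding an integer to every sample of a
# distribution leaves the circular distance unchanged.
#
# >>> u = np.array([[1.2, 1.5, 1.8]])   # equal to [[0.2, 0.5, 0.8]] modulo 1
# >>> v = np.array([[0.4, 0.5, 0.7]])
# >>> wasserstein_circle(u.T, v.T)
# array([0.1])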

def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:S^1
Samples need to be in :math:S^1\cong [0,1[. If they are on :math:\mathbb{R},
takes the value modulo 1.
If the values are on :math:S^1\subset\mathbb{R}^2, it is required to first find the coordinates
using e.g. the atan2 function.

.. math::
W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}

where:

- :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`

For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with

.. math::
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},

using e.g. ot.utils.get_coordinate_circle(x)

Parameters
----------
u_values: ndarray, shape (n, ...)
Samples
u_weights : ndarray, shape (n, ...), optional
samples weights in the source domain

Returns
-------
loss: float
Cost associated to the optimal transportation

Examples
--------
>>> x0 = np.array([[0], [0.2], [0.4]])
>>> semidiscrete_wasserstein2_unif_circle(x0)
array([0.02111111])

References
----------
.. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
"""

if u_weights is not None:
nx = get_backend(u_values, u_weights)
else:
nx = get_backend(u_values)

n = u_values.shape[0]

u_values = u_values % 1

if len(u_values.shape) == 1:
u_values = nx.reshape(u_values, (n, 1))

if u_weights is None:
u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values)
elif u_weights.ndim != u_values.ndim:
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)

u_values = nx.sort(u_values, 0)
u_cdf = nx.cumsum(u_weights, 0)
u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])

cpt1 = nx.sum(u_weights * u_values**2, axis=0)
u_mean = nx.sum(u_weights * u_values, axis=0)

ns = 1 - u_weights - 2 * u_cdf[:-1]
cpt2 = nx.sum(u_values * u_weights * ns, axis=0)

return cpt1 - u_mean**2 + cpt2 + 1 / 12
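
# Worked check of the closed form (illustrative, not part of the original
# source): for x = {0, 0.2, 0.4} with uniform weights 1/3, evaluating the
# formula above by hand reproduces the doctest value 0.02111111.
#
# >>> x = np.array([0., 0.2, 0.4])
# >>> a = np.full(3, 1 / 3)                                  # uniform weights
# >>> cum_prev = np.concatenate(([0.], np.cumsum(a)[:-1]))   # sum_{k<i} alpha_k
# >>> cost = (np.sum(a * x ** 2) - np.sum(a * x) ** 2
# ...         + np.sum(a * x * (1 - a - 2 * cum_prev)) + 1 / 12)
# >>> float(np.round(cost, 8))
# 0.02111111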