# -*- coding: utf-8 -*-
"""
Solvers for optimal transport on the circle (circular Wasserstein distances and linear circular OT)
"""
# Author: Clément Bonet <clement.bonet.mapp@polytechnique.edu>
#
# License: MIT License
import numpy as np
import warnings
from ..backend import get_backend
from .solver_1d import quantile_function
def roll_cols(M, shifts):
    r"""
    Circularly shift the entries of every row of a 2d matrix, with an offset
    that may differ from one row to the next.

    Parameters
    ----------
    M : ndarray, shape (nr, nc)
        Matrix to shift
    shifts: int or ndarray, shape (nr, 1)
        Offset applied to each row (a single int applies the same offset to
        every row). The entry at column ``j`` is moved to column
        ``(j + shift) % nc``.

    Returns
    -------
    Shifted array

    Examples
    --------
    >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
    >>> roll_cols(M, 2)
    array([[2, 3, 1],
           [5, 6, 4],
           [8, 9, 7]])
    >>> roll_cols(M, np.array([[1],[2],[1]]))
    array([[3, 1, 2],
           [5, 6, 4],
           [9, 7, 8]])

    References
    ----------
    https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
    """
    nx = get_backend(M)
    n_rows, n_cols = M.shape

    # Grid of destination column indices, one identical row per row of M.
    col_ids = nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols))
    col_grid = nx.tile(col_ids, (n_rows, 1))

    # For each destination column, the (wrapped) column of M to read from.
    source_cols = (col_grid - shifts) % n_cols

    return nx.take_along_axis(M, source_cols, 1)
def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
    r"""Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [44])

    Parameters
    ----------
    theta: array-like, shape (n_batch, m)
        Cuts on the circle (same shape as `v_cdf`)
    u_values: array-like, shape (n_batch, n)
        locations of the first empirical distribution
    v_values: array-like, shape (n_batch, m)
        locations of the second empirical distribution
    u_cdf: array-like, shape (n_batch, n)
        cdf of the first empirical distribution
    v_cdf: array-like, shape (n_batch, m)
        cdf of the second empirical distribution
    p: float, optional = 2
        Power p used for computing the Wasserstein distance

    Returns
    -------
    dCp: array-like, shape (n_batch, 1)
        The batched right derivative
    dCm: array-like, shape (n_batch, 1)
        The batched left derivative

    References
    ---------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    """
    nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)

    n = u_values.shape[-1]
    theta_floor = nx.floor(theta)

    # Work on a copy: the locations of v are shifted in place below.
    v_values = nx.copy(v_values)

    # cdf of v translated by the fractional part of theta
    v_cdf_theta = v_cdf - (theta - theta_floor)
    mask_p = v_cdf_theta >= 0
    mask_n = v_cdf_theta < 0

    # Entries whose translated cdf became negative wrap around the circle,
    # hence the extra +1 on their location.
    v_values[mask_n] += theta_floor[mask_n] + 1
    v_values[mask_p] += theta_floor[mask_p]

    if nx.any(mask_n) and nx.any(mask_p):
        v_cdf_theta[mask_n] += 1

    # Rotate each row so that it starts at the smallest non-wrapped cdf value
    # (wrapped entries are masked with +inf before taking the argmin).
    v_cdf_masked = nx.copy(v_cdf_theta)
    v_cdf_masked[mask_n] = np.inf
    shift = nx.reshape(-nx.argmin(v_cdf_masked, axis=-1), (-1, 1))

    v_cdf_theta = roll_cols(v_cdf_theta, shift)
    v_values = roll_cols(v_values, shift)

    # Close the loop: append the first location translated by one full turn.
    v_values = nx.concatenate(
        [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1
    )

    # Same wrap-around extension for u (used by the left derivative).
    u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
    u_valuesm = nx.concatenate(
        [u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1
    )

    if nx.__name__ == "torch":
        # this is to ensure the best performance for torch searchsorted
        # and avoid a warning related to non-contiguous arrays
        u_cdf = u_cdf.contiguous()
        u_cdfm = u_cdfm.contiguous()
        v_cdf_theta = v_cdf_theta.contiguous()

    # quantiles of F_u evaluated in F_v^\theta
    u_index = nx.searchsorted(u_cdf, v_cdf_theta)
    u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)

    # Deal with 1
    u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
    u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)

    v_next = v_values[:, 1:]
    v_prev = v_values[:, :-1]

    def _one_sided_derivative(u_quantiles):
        # Sum over the atoms of |u - v_next|^p - |u - v_prev|^p
        return nx.sum(
            nx.power(nx.abs(u_quantiles - v_next), p)
            - nx.power(nx.abs(u_quantiles - v_prev), p),
            axis=-1,
        )

    dCp = _one_sided_derivative(u_icdf_theta)
    dCm = _one_sided_derivative(u_icdfm_theta)

    return dCp.reshape(-1, 1), dCm.reshape(-1, 1)
def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
    r"""Computes the cost (Equation (6.2) of [44])

    Parameters
    ----------
    theta: array-like, shape (n_batch, m)
        Cuts on the circle (same shape as `v_cdf`)
    u_values: array-like, shape (n_batch, n)
        locations of the first empirical distribution
    v_values: array-like, shape (n_batch, m)
        locations of the second empirical distribution
    u_cdf: array-like, shape (n_batch, n)
        cdf of the first empirical distribution
    v_cdf: array-like, shape (n_batch, m)
        cdf of the second empirical distribution
    p: float
        Power p used for computing the Wasserstein distance

    Returns
    -------
    ot_cost: array-like, shape (n_batch,)
        OT cost evaluated at theta

    References
    ---------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    """
    nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)

    # Work on a copy: the locations of v are shifted in place below.
    v_values = nx.copy(v_values)

    m = v_values.shape[-1]
    n = u_values.shape[-1]

    theta_floor = nx.floor(theta)

    # cdf of v translated by the fractional part of theta
    v_cdf_theta = v_cdf - (theta - theta_floor)

    mask_p = v_cdf_theta >= 0
    mask_n = v_cdf_theta < 0

    # Entries whose translated cdf became negative wrap around the circle,
    # hence the extra +1 on their location.
    v_values[mask_n] += theta_floor[mask_n] + 1
    v_values[mask_p] += theta_floor[mask_p]

    if nx.any(mask_n) and nx.any(mask_p):
        v_cdf_theta[mask_n] += 1

    # Put negative values at the end
    v_cdf_theta2 = nx.copy(v_cdf_theta)
    v_cdf_theta2[mask_n] = np.inf
    shift = nx.reshape(-nx.argmin(v_cdf_theta2, axis=-1), (-1, 1))

    v_cdf_theta = roll_cols(v_cdf_theta, shift)
    v_values = roll_cols(v_values, shift)

    # Close the loop once: append the first location translated by one full
    # turn. v_values now has m + 1 columns, which is all that the clipped
    # index (0..m) below can address.
    v_values = nx.concatenate(
        [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1
    )

    # Compute abscissa: merged, sorted cdf levels and the gaps between them
    cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
    cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])

    delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]

    if nx.__name__ == "torch":
        # this is to ensure the best performance for torch searchsorted
        # and avoid a warning related to non-contiguous arrays
        u_cdf = u_cdf.contiguous()
        v_cdf_theta = v_cdf_theta.contiguous()
        cdf_axis = cdf_axis.contiguous()

    # Compute icdf
    u_index = nx.searchsorted(u_cdf, cdf_axis)
    u_icdf = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)

    v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
    v_icdf = nx.take_along_axis(v_values, nx.clip(v_index, 0, m), -1)

    if p == 1:
        ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
    else:
        ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)

    return ot_cost
def binary_search_circle(
    u_values,
    v_values,
    u_weights=None,
    v_weights=None,
    p=1,
    Lm=10,
    Lp=10,
    tm=-1,
    tp=1,
    eps=1e-6,
    require_sort=True,
    log=False,
):
    r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
    using e.g. the atan2 function.

    .. math::
        W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q

    where:

    - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`

    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with

    .. math::
        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}

    using e.g. ot.utils.get_coordinate_circle(x)

    The function runs on backend but tensorflow and jax are not supported.

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    v_values : ndarray, shape (m, ...)
        samples in the target domain (coordinates on [0,1[)
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    v_weights : ndarray, shape (m, ...), optional
        samples weights in the target domain
    p : float, optional (default=1)
        Power p used for computing the Wasserstein distance
    Lm : int, optional
        Lower bound dC
    Lp : int, optional
        Upper bound dC
    tm: float, optional
        Lower bound theta
    tp: float, optional
        Upper bound theta
    eps: float, optional
        Stopping condition
    require_sort: bool, optional
        If True, sort the values.
    log: bool, optional
        If True, returns also the optimal theta

    Returns
    -------
    loss: float/array-like, shape (...)
        Batched cost associated to the optimal transportation
    log: dict, optional
        log dictionary returned only if log==True in parameters

    Examples
    --------
    >>> u = np.array([[0.2,0.5,0.8]])%1
    >>> v = np.array([[0.4,0.5,0.7]])%1
    >>> binary_search_circle(u.T, v.T, p=1)
    array([0.1])

    References
    ----------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html
    """
    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

    nx = get_backend(u_values, v_values, u_weights, v_weights)

    n = u_values.shape[0]
    m = v_values.shape[0]

    # Promote 1d inputs to a single-batch column layout (n_samples, 1).
    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))
    if len(v_values.shape) == 1:
        v_values = nx.reshape(v_values, (m, 1))

    if u_values.shape[1] != v_values.shape[1]:
        raise ValueError(
            "u and v must have the same number of batches {} and {} respectively given".format(
                u_values.shape[1], v_values.shape[1]
            )
        )

    # Map the samples onto [0, 1[
    u_values = u_values % 1
    v_values = v_values % 1

    # Default to uniform weights; broadcast 1d weights over the batch axis.
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    if require_sort:
        # Sort the samples and carry the weights along.
        u_sorter = nx.argsort(u_values, 0)
        u_values = nx.take_along_axis(u_values, u_sorter, 0)
        u_weights = nx.take_along_axis(u_weights, u_sorter, 0)

        v_sorter = nx.argsort(v_values, 0)
        v_values = nx.take_along_axis(v_values, v_sorter, 0)
        v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

    # From here on everything is laid out as (n_batch, n_samples).
    u_cdf = nx.cumsum(u_weights, 0).T
    v_cdf = nx.cumsum(v_weights, 0).T
    u_values = u_values.T
    v_values = v_values.T

    n_batch = u_values.shape[0]
    L = max(Lm, Lp)
    width_tol = eps / L

    # Per-batch search interval [tm, tp] and its midpoint tc, replicated over
    # the m columns so they can be combined elementwise with v_cdf.
    batch_ones = nx.reshape(nx.ones((n_batch,), type_as=u_values), (-1, 1))
    tm = nx.tile(tm * batch_ones, (1, m))
    tp = nx.tile(tp * batch_ones, (1, m))
    tc = (tm + tp) / 2

    done = nx.zeros((n_batch, m))

    while nx.any(1 - done):
        dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p)

        # A batch is done when the one-sided derivatives at tc do not have
        # the same strict sign.
        done = ((dCp * dCm) <= 0) * 1
        not_done = 1 - done

        # Not-done entries whose interval is already narrower than eps / L.
        mask = ((tp - tm) < width_tol) * not_done

        if nx.any(mask):
            # can probably be improved by computing only relevant values
            _, dCmtp = derivative_cost_on_circle(
                tp, u_values, v_values, u_cdf, v_cdf, p
            )
            dCptm, _ = derivative_cost_on_circle(
                tm, u_values, v_values, u_cdf, v_cdf, p
            )
            Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(
                -1, 1
            )
            Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(
                -1, 1
            )

            # Avoid warning raised when dCptm - dCmtp == 0, for which
            # tc is not updated as mask_end is False,
            # see Issue #738
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=RuntimeWarning)
                slope_gap = dCptm - dCmtp
                mask_end = mask * (nx.abs(slope_gap) > 0.001)
                tc[mask_end > 0] = (
                    (Ctp - Ctm + tm * dCptm - tp * dCmtp) / slope_gap
                )[mask_end > 0]

            done[nx.prod(mask, axis=-1) > 0] = 1
        elif nx.any(not_done):
            # Bisection step: move the bound on the side indicated by the
            # sign of the right derivative, then recentre tc.
            unmasked = 1 - mask
            raise_lower = (unmasked * (dCp < 0)) > 0
            drop_upper = (unmasked * (dCp >= 0)) > 0
            recentre = (unmasked * not_done) > 0

            tm[raise_lower] = tc[raise_lower]
            tp[drop_upper] = tc[drop_upper]
            tc[recentre] = (tm[recentre] + tp[recentre]) / 2

    w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p)

    if log:
        return w, {"optimal_theta": tc[:, 0]}
    return w
def wasserstein1_circle(
    u_values, v_values, u_weights=None, v_weights=None, require_sort=True
):
    r"""Computes the 1-Wasserstein distance on the circle using the level median [45].
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates
    using e.g. the atan2 function.
    The function runs on backend but tensorflow and jax are not supported.

    .. math::
        W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    v_values : ndarray, shape (m, ...)
        samples in the target domain (coordinates on [0,1[)
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    v_weights : ndarray, shape (m, ...), optional
        samples weights in the target domain
    require_sort: bool, optional
        If True, sort the values.

    Returns
    -------
    loss: float/array-like, shape (...)
        Batched cost associated to the optimal transportation

    Examples
    --------
    >>> u = np.array([[0.2,0.5,0.8]])%1
    >>> v = np.array([[0.4,0.5,0.7]])%1
    >>> wasserstein1_circle(u.T, v.T)
    array([0.1])

    References
    ----------
    .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
    .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
    """
    nx = get_backend(u_values, v_values, u_weights, v_weights)

    n = u_values.shape[0]
    m = v_values.shape[0]

    # Promote 1d inputs to a single-batch column layout (n_samples, 1).
    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))
    if len(v_values.shape) == 1:
        v_values = nx.reshape(v_values, (m, 1))

    if u_values.shape[1] != v_values.shape[1]:
        raise ValueError(
            "u and v must have the same number of batchs {} and {} respectively given".format(
                u_values.shape[1], v_values.shape[1]
            )
        )

    # Map the samples onto [0, 1[
    u_values = u_values % 1
    v_values = v_values % 1

    # Default to uniform weights; broadcast 1d weights over the batch axis.
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    if require_sort:
        # Sort the samples and carry the weights along.
        u_sorter = nx.argsort(u_values, 0)
        u_values = nx.take_along_axis(u_values, u_sorter, 0)
        u_weights = nx.take_along_axis(u_weights, u_sorter, 0)

        v_sorter = nx.argsort(v_values, 0)
        v_values = nx.take_along_axis(v_values, v_sorter, 0)
        v_weights = nx.take_along_axis(v_weights, v_sorter, 0)

    # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
    # Merge both supports; u carries positive mass and v negative mass so
    # that the running sum along the merged support is F_u - F_v.
    pooled_values = nx.concatenate((u_values, v_values), 0)
    signed_weights = nx.concatenate((u_weights, -v_weights), 0)

    values_sorted, values_sorter = nx.sort2(pooled_values, 0)
    cdf_diff = nx.cumsum(nx.take_along_axis(signed_weights, values_sorter, 0), 0)
    cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)

    # Length of the arc following each support point (the last one runs to 1).
    values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
    delta = values_sorted[1:, ...] - values_sorted[:-1, ...]

    # Level median: walk the levels of F_u - F_v in increasing order and stop
    # at the first one where the accumulated arc length reaches 1/2.
    weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
    sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
    sum_weights[sum_weights < 0] = np.inf
    inds = nx.argmin(sum_weights, axis=0)

    levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)

    return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
def wasserstein_circle(
    u_values,
    v_values,
    u_weights=None,
    v_weights=None,
    p=1,
    Lm=10,
    Lp=10,
    tm=-1,
    tp=1,
    eps=1e-6,
    require_sort=True,
):
    r"""Computes the Wasserstein distance on the circle using either :ref:`[45] <references-wasserstein-circle>` for p=1 or
    the binary search algorithm proposed in :ref:`[44] <references-wasserstein-circle>` otherwise.
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates
    using e.g. the atan2 function.

    General loss returned:

    .. math::
        OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q

    For p=1, [45]

    .. math::
        W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t

    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with

    .. math::
        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}

    using e.g. ot.utils.get_coordinate_circle(x)

    The function runs on backend but tensorflow and jax are not supported.

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    v_values : ndarray, shape (m, ...)
        samples in the target domain (coordinates on [0,1[)
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    v_weights : ndarray, shape (m, ...), optional
        samples weights in the target domain
    p : float, optional (default=1)
        Power p used for computing the Wasserstein distance
    Lm : int, optional
        Lower bound dC. For p>1.
    Lp : int, optional
        Upper bound dC. For p>1.
    tm: float, optional
        Lower bound theta. For p>1.
    tp: float, optional
        Upper bound theta. For p>1.
    eps: float, optional
        Stopping condition. For p>1.
    require_sort: bool, optional
        If True, sort the values.

    Returns
    -------
    loss: float/array-like, shape (...)
        Batched cost associated to the optimal transportation

    Examples
    --------
    >>> u = np.array([[0.2,0.5,0.8]])%1
    >>> v = np.array([[0.4,0.5,0.7]])%1
    >>> wasserstein_circle(u.T, v.T)
    array([0.1])


    .. _references-wasserstein-circle:
    References
    ----------
    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
    .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
    """
    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)

    if p == 1:
        # Closed-form level-median solver [45]; the binary-search parameters
        # (Lm, Lp, tm, tp, eps) only apply to p > 1 and are not used here.
        return wasserstein1_circle(
            u_values,
            v_values,
            u_weights,
            v_weights,
            require_sort=require_sort,
        )

    return binary_search_circle(
        u_values,
        v_values,
        u_weights,
        v_weights,
        p=p,
        Lm=Lm,
        Lp=Lp,
        tm=tm,
        tp=tp,
        eps=eps,
        require_sort=require_sort,
    )
def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
    r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
    using e.g. the atan2 function.

    .. math::
        W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}

    where:

    - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`

    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with

    .. math::
        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},

    using e.g. ot.utils.get_coordinate_circle(x).

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        Samples
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain. ``u_weights[i]`` is the mass of
        ``u_values[i]``; the weights are reordered together with the samples
        when the samples are sorted.

    Returns
    -------
    loss: float/array-like, shape (...)
        Batched cost associated to the optimal transportation

    Examples
    --------
    >>> x0 = np.array([[0], [0.2], [0.4]])
    >>> semidiscrete_wasserstein2_unif_circle(x0)
    array([0.02111111])

    References
    ----------
    .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
    """
    nx = get_backend(u_values, u_weights)

    n = u_values.shape[0]

    # Map the samples onto [0, 1[
    u_values = u_values % 1

    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))

    # Default to uniform weights; broadcast 1d weights over the batch axis.
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)

    # The closed form assumes x_1 <= ... <= x_n with alpha_i attached to x_i,
    # so the weights must follow the same permutation as the samples.
    u_sorter = nx.argsort(u_values, 0)
    u_values = nx.take_along_axis(u_values, u_sorter, 0)
    u_weights = nx.take_along_axis(u_weights, u_sorter, 0)

    # Leading zero row so that u_cdf[i] = sum_{k < i} alpha_k
    u_cdf = nx.cumsum(u_weights, 0)
    u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)])

    cpt1 = nx.sum(u_weights * u_values**2, axis=0)
    u_mean = nx.sum(u_weights * u_values, axis=0)

    ns = 1 - u_weights - 2 * u_cdf[:-1]
    cpt2 = nx.sum(u_values * u_weights * ns, axis=0)

    return cpt1 - u_mean**2 + cpt2 + 1 / 12
def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True):
    r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference
    :math:`\eta=\mathrm{Unif}(S^1)` evaluated in :math:`x`.

    For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[78] <references-lcot>`)

    .. math::
        \hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12\big) - x.

    Parameters
    ----------
    x : ndarray, shape (m,)
        Points in [0,1[ where to evaluate the embedding
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    require_sort: bool, optional
        If True, sort the values (and reorder the weights accordingly).

    Returns
    -------
    embedding: ndarray of shape (m, ...)
        Embedding evaluated at :math:`x`


    .. _references-lcot:
    References
    ----------
    .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations.
    """
    nx = get_backend(u_values, u_weights)

    n = u_values.shape[0]

    # Map the samples onto [0, 1[
    u_values = u_values % 1

    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))

    # Default to uniform weights; broadcast 1d weights over the batch axis.
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)

    if require_sort:
        order = nx.argsort(u_values, 0)
        u_values = nx.take_along_axis(u_values, order, 0)
        u_weights = nx.take_along_axis(u_weights, order, 0)

    # Cumulative weights with a leading zero row
    u_cdf = nx.zero_pad(nx.cumsum(u_weights, 0), [(1, 0), (0, 0)])

    # Quantile levels x - E[mu] + 1/2, one column per batch: shape (m, ...)
    x_col = x[:, None]
    u_mean = nx.sum(u_values * u_weights, axis=0)
    q_s = x_col - u_mean[None] + 0.5

    u_quantiles = quantile_function(q_s % 1, u_cdf, u_values)

    return (u_quantiles - x_col) % 1
def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None):
    r"""Computes the Linear Circular Optimal Transport distance from :ref:`[78] <references-lcot>` using :math:`\eta=\mathrm{Unif}(S^1)`
    as reference measure.
    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
    takes the value modulo 1.
    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
    using e.g. the atan2 function.

    General loss returned:

    .. math::
        \mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t

    where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`,
    and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`.

    The integral is approximated by an average over 100 equally spaced
    points of [0,1[.

    Parameters
    ----------
    u_values : ndarray, shape (n, ...)
        samples in the source domain (coordinates on [0,1[)
    v_values : ndarray, shape (m, ...), optional
        samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution
    u_weights : ndarray, shape (n, ...), optional
        samples weights in the source domain
    v_weights : ndarray, shape (m, ...), optional
        samples weights in the target domain

    Returns
    -------
    loss: float/array-like, shape (...)
        Batched cost associated to the linear optimal transportation

    Examples
    --------
    >>> u = np.array([[0.2,0.5,0.8]])%1
    >>> v = np.array([[0.4,0.5,0.7]])%1
    >>> linear_circular_ot(u.T, v.T)
    array([0.0127])


    .. _references-lcot:
    References
    ----------
    .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations.
    """
    nx = get_backend(u_values, u_weights)

    n = u_values.shape[0]

    # Map the samples onto [0, 1[
    u_values = u_values % 1

    if len(u_values.shape) == 1:
        u_values = nx.reshape(u_values, (n, 1))

    # Default to uniform weights; broadcast 1d weights over the batch axis.
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)

    # Evaluation grid: 0, 0.01, ..., 0.99 (the endpoint 1 is dropped)
    grid = nx.linspace(0, 1, 101, type_as=u_values)[:-1]

    emb_u = linear_circular_embedding(grid, u_values, u_weights)

    if v_values is None:
        # Distance against the uniform reference measure
        gap_u = nx.abs(emb_u)
        dist_u = nx.minimum(gap_u, 1 - gap_u)
        return nx.mean(dist_u**2, axis=0)

    m = v_values.shape[0]
    if len(v_values.shape) == 1:
        v_values = nx.reshape(v_values, (m, 1))

    if u_values.shape[1] != v_values.shape[1]:
        raise ValueError(
            "u and v must have the same number of batchs {} and {} respectively given".format(
                u_values.shape[1], v_values.shape[1]
            )
        )

    emb_v = linear_circular_embedding(grid, v_values, v_weights)

    # Geodesic distance on the circle between the two embeddings
    gap_uv = nx.abs(emb_u - emb_v)
    dist_uv = nx.minimum(gap_uv, 1 - gap_uv)
    return nx.mean(dist_uv**2, axis=0)