# -*- coding: utf-8 -*-
"""
1D Unbalanced OT solvers
"""
# Author: Clément Bonet <clement.bonet.mapp@polytechnique.edu>
#
# License: MIT License
from ..backend import get_backend
from ..utils import get_parameter_pair
from ..lp.solver_1d import emd_1d_dual_backprop, wasserstein_1d
def rescale_potentials(f, g, a, b, rho1, rho2, nx):
    r"""
    Compute the optimal translation :math:`\lambda` of the dual potentials in
    the translation invariant dual of UOT with KL regularization, see
    Proposition 2 in :ref:`[73] <references-uot>`.

    With finite `rho1` and `rho2`, the returned value is

    .. math::
        \lambda = \frac{\rho_1 \rho_2}{\rho_1 + \rho_2}
        \left( \log \sum_i a_i e^{-f_i / \rho_1}
        - \log \sum_j b_j e^{-g_j / \rho_2} \right).

    When one of the relaxation terms is infinite, the corresponding
    log-sum-exp is replaced by the log of the total mass of that marginal and
    the scaling factor is the remaining finite term. When both are infinite,
    the translation is zero.

    Parameters
    ----------
    f: array-like, shape (n, ...)
        first dual potential
    g: array-like, shape (m, ...)
        second dual potential
    a: array-like, shape (n, ...)
        weights of the first empirical distribution
    b: array-like, shape (m, ...)
        weights of the second empirical distribution
    rho1: float
        Marginal relaxation term for the first marginal
    rho2: float
        Marginal relaxation term for the second marginal
    nx: module
        backend module

    Returns
    -------
    transl: array-like, shape (...)
        optimal translation


    .. _references-uot:
    References
    ----------
    .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
       Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
       In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
    """
    infinity = float("inf")
    first_is_hard = rho1 == infinity

    # Balanced case: both marginals are hard constraints, nothing to translate.
    if first_is_hard and rho2 == infinity:
        return nx.zeros(shape=nx.sum(f, axis=0).shape, type_as=f)

    second_is_hard = rho2 == infinity

    # Log-mass term of the first marginal (reduced over the sample axis).
    if first_is_hard:
        log_mass_a = nx.log(nx.sum(a, axis=0))
    else:
        log_mass_a = nx.logsumexp(-f / rho1 + nx.log(a), axis=0)

    # Log-mass term of the second marginal (reduced over the sample axis).
    if second_is_hard:
        log_mass_b = nx.log(nx.sum(b, axis=0))
    else:
        log_mass_b = nx.logsumexp(-g / rho2 + nx.log(b), axis=0)

    # Scaling factor: the finite relaxation term in the semi-relaxed cases,
    # rho1 * rho2 / (rho1 + rho2) otherwise.
    if first_is_hard:
        scale = rho2
    elif second_is_hard:
        scale = rho1
    else:
        scale = (rho1 * rho2) / (rho1 + rho2)

    return scale * (log_mass_a - log_mass_b)
def get_reweighted_marginal_uot(
    f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx
):
    r"""
    Translate the dual potentials and compute the reweighted marginals used in
    the FW algorithm for the 1D UOT problem with KL regularization.

    The potentials are first shifted by the translation returned by
    `rescale_potentials` (`f` is increased, `g` is decreased by it). Each
    marginal with a finite relaxation term is then reweighted by
    `exp(-potential / reg_m)`; a marginal with an infinite relaxation term is
    left unchanged. The reweighted marginals are returned normalized by their
    own mass so that they sum to 1 along the first axis.

    Parameters
    ----------
    f: array-like, shape (n, ...)
        first dual potential
    g: array-like, shape (m, ...)
        second dual potential
    u_weights_sorted: array-like, shape (n, ...)
        weights of the first empirical distribution, sorted w.r.t. the support
    v_weights_sorted: array-like, shape (m, ...)
        weights of the second empirical distribution, sorted w.r.t. the support
    reg_m1: float
        Marginal relaxation term for the first marginal
    reg_m2: float
        Marginal relaxation term for the second marginal
    nx: module
        backend module

    Returns
    -------
    f: array-like, shape (n, ...)
        first dual potential, after translation
    g: array-like, shape (m, ...)
        second dual potential, after translation
    u_rescaled: array-like, shape (n, ...)
        reweighted first marginal, normalized by the mass
    v_rescaled: array-like, shape (m, ...)
        reweighted second marginal, normalized by the mass
    full_mass: array-like, shape (...)
        mass of the reweighted first marginal
    """
    infinity = float("inf")

    shift = rescale_potentials(
        f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx
    )
    # Add a leading axis so the shift broadcasts over the sample axis.
    f = f + shift[None]
    g = g - shift[None]

    u_reweighted = (
        u_weights_sorted
        if reg_m1 == infinity
        else u_weights_sorted * nx.exp(-f / reg_m1)
    )
    v_reweighted = (
        v_weights_sorted
        if reg_m2 == infinity
        else v_weights_sorted * nx.exp(-g / reg_m2)
    )

    # Per-batch masses, kept with a leading axis for the normalization below.
    u_mass = nx.sum(u_reweighted, axis=0, keepdims=True)
    v_mass = nx.sum(v_reweighted, axis=0, keepdims=True)
    full_mass = nx.sum(u_reweighted, axis=0)

    return f, g, u_reweighted / u_mass, v_reweighted / v_mass, full_mass
# [docs]  (apparently a documentation-link marker left over from the rendered source page; not code)
def uot_1d(
    u_values,
    v_values,
    reg_m,
    u_weights=None,
    v_weights=None,
    p=2,
    require_sort=True,
    numItermax=10,
    returnCost="linear",
    log=False,
):
    r"""
    Solves the 1D unbalanced OT problem with KL regularization.

    The function implements the Frank-Wolfe algorithm to solve the dual problem,
    as proposed in :ref:`[73] <references-uot>`.

    The unbalanced OT problem reads

    .. math::
        \mathrm{UOT}_p^p(\mu,\nu) = \min_{\gamma \in \mathcal{M}_{+}(\mathbb{R}\times\mathbb{R})} W_p^p(\pi^1_\#\gamma,\pi^2_\#\gamma) + \mathrm{reg_{m}}_1 \mathrm{KL}(\pi^1_\#\gamma|\mu) + \mathrm{reg_{m}}_2 \mathrm{KL}(\pi^2_\#\gamma|\nu).

    .. warning:: This function only works in pytorch or jax as it uses autodifferentiation to compute the potentials. It is not maintained in jax.

    Parameters
    ----------
    u_values: array-like, shape (n, ...)
        locations of the first empirical distribution
    v_values: array-like, shape (m, ...)
        locations of the second empirical distribution
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If `reg_m` is a scalar or an indexable object of length 1,
        then the same `reg_m` is applied to both marginal relaxations.
        The balanced OT can be recovered using `reg_m=float("inf")`.
        For semi-relaxed case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If `reg_m` is an array, it must have the same backend as input arrays `(u_values, v_values)`.
    u_weights: array-like, shape (n, ...), optional
        weights of the first empirical distribution, if None then uniform weights are used.
        If its number of dimensions differs from the one of `u_values`, it is
        repeated along a new trailing axis of size `u_values.shape[-1]`.
    v_weights: array-like, shape (m, ...), optional
        weights of the second empirical distribution, if None then uniform weights are used.
        If its number of dimensions differs from the one of `v_values`, it is
        repeated along a new trailing axis of size `v_values.shape[-1]`.
    p: int, optional
        order of the ground metric used, should be at least 1, default is 2
    require_sort: bool, optional
        sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to
        the function, default is True
    numItermax: int, optional
        number of Frank-Wolfe iterations, default is 10
    returnCost: string, optional (default = "linear")
        If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
        If `returnCost` = "total", then return the total unbalanced OT loss.
    log: bool, optional
        If True, additionally return a dictionary with the dual potentials and
        both costs, default is False

    Returns
    -------
    u_reweighted: array-like shape (n, ...)
        First marginal reweighted
    v_reweighted: array-like shape (m, ...)
        Second marginal reweighted
    loss: float/array-like, shape (...)
        The batched 1D UOT (linear or total cost, depending on `returnCost`)
    log: dict, optional
        If `log` is True, then returns a dictionary containing the dual potentials
        (keys "f" and "g"), the total cost ("total_cost") and the linear cost
        ("linear_cost").

    Raises
    ------
    ValueError
        If `returnCost` is neither "linear" nor "total".
    AssertionError
        If the backend of the inputs is neither torch nor jax.


    .. _references-uot:
    References
    ----------
    .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
       Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
       In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
    """
    # Fail early on an unknown cost selector instead of hitting an unbound
    # local variable after all the computation has been done.
    if returnCost not in ("linear", "total"):
        raise ValueError(
            f"Unknown returnCost '{returnCost}', expected 'linear' or 'total'"
        )

    nx = get_backend(u_values, v_values, u_weights, v_weights)
    assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax"

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    n = u_values.shape[0]
    m = v_values.shape[0]

    # Init weights or broadcast if necessary
    if u_weights is None:
        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
    elif u_weights.ndim != u_values.ndim:
        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)

    if v_weights is None:
        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
    elif v_weights.ndim != v_values.ndim:
        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)

    # Sort w.r.t. support if not already done
    if require_sort:
        u_sorter = nx.argsort(u_values, 0)
        u_rev_sorter = nx.argsort(u_sorter, 0)
        u_values_sorted = nx.take_along_axis(u_values, u_sorter, 0)

        v_sorter = nx.argsort(v_values, 0)
        v_rev_sorter = nx.argsort(v_sorter, 0)
        v_values_sorted = nx.take_along_axis(v_values, v_sorter, 0)

        u_weights_sorted = nx.take_along_axis(u_weights, u_sorter, 0)
        v_weights_sorted = nx.take_along_axis(v_weights, v_sorter, 0)
    else:
        # The caller guarantees the supports are already sorted.
        u_values_sorted, v_values_sorted = u_values, v_values
        u_weights_sorted, v_weights_sorted = u_weights, v_weights

    f = nx.zeros(u_weights.shape, type_as=u_weights)
    g = nx.zeros(v_weights.shape, type_as=v_weights)

    for i in range(numItermax):
        # FW steps: reweight the marginals with the current potentials, ...
        f, g, u_rescaled, v_rescaled, _ = get_reweighted_marginal_uot(
            f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx
        )

        # ... get the dual potentials between the normalized reweighted
        # marginals (the loss returned here is not needed), ...
        fd, gd, _ = emd_1d_dual_backprop(
            u_values_sorted,
            v_values_sorted,
            u_weights=u_rescaled,
            v_weights=v_rescaled,
            p=p,
            require_sort=False,
        )

        # ... and move towards them with step size 2 / (2 + i).
        t = 2.0 / (2.0 + i)
        f = f + t * (fd - f)
        g = g + t * (gd - g)

    # Final translation and reweighting with the last potentials
    f, g, u_rescaled, v_rescaled, full_mass = get_reweighted_marginal_uot(
        f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx
    )

    loss = wasserstein_1d(
        u_values_sorted,
        v_values_sorted,
        u_rescaled,
        v_rescaled,
        p=p,
        require_sort=False,
    )

    # Undo the sorting (if any) and restore the mass of the marginals
    if require_sort:
        f = nx.take_along_axis(f, u_rev_sorter, 0)
        g = nx.take_along_axis(g, v_rev_sorter, 0)
        u_reweighted = nx.take_along_axis(u_rescaled, u_rev_sorter, 0) * full_mass
        v_reweighted = nx.take_along_axis(v_rescaled, v_rev_sorter, 0) * full_mass
    else:
        u_reweighted = u_rescaled * full_mass
        v_reweighted = v_rescaled * full_mass

    # rescale OT loss
    linear_loss = loss * full_mass

    # Add the KL penalty of every relaxed marginal. The KL is always reduced
    # over the sample axis only, so that batched inputs give one value per batch.
    uot_loss = linear_loss
    if reg_m1 != float("inf"):
        uot_loss = uot_loss + reg_m1 * nx.kl_div(
            u_reweighted, u_weights, mass=True, axis=0
        )
    if reg_m2 != float("inf"):
        uot_loss = uot_loss + reg_m2 * nx.kl_div(
            v_reweighted, v_weights, mass=True, axis=0
        )

    out_loss = linear_loss if returnCost == "linear" else uot_loss

    if log:
        dico = {"f": f, "g": g, "total_cost": uot_loss, "linear_cost": linear_loss}
        return u_reweighted, v_reweighted, out_loss, dico
    return u_reweighted, v_reweighted, out_loss