# Source code for ot.unbalanced._sliced

# -*- coding: utf-8 -*-
"""
Sliced Unbalanced OT solvers
"""

# Author: Clément Bonet <clement.bonet.mapp@polytechnique.edu>
#
# License: MIT License

from ..backend import get_backend
from ..utils import get_parameter_pair
from ..sliced import get_random_projections
from ._solver_1d import rescale_potentials, uot_1d
from ..lp.solver_1d import emd_1d_dual_backprop, wasserstein_1d


def sliced_unbalanced_ot(
    X_s,
    X_t,
    reg_m,
    a=None,
    b=None,
    n_projections=50,
    p=2,
    projections=None,
    seed=None,
    numItermax=10,
    log=False,
):
    r"""Compute the Sliced Unbalanced Optimal Transport (SUOT) between two empirical distributions.

    Each 1D UOT problem is computed with KL regularization and solved with a
    Frank-Wolfe algorithm, see :ref:`[82] <references-suot>`. The Sliced
    Unbalanced Optimal Transport (SUOT) is defined as

    .. math::
        \mathrm{SUOT}_p^p(\mu, \nu) = \int_{S^{d-1}} \mathrm{UOT}_p^p(P^\theta_\#\mu, P^\theta_\#\nu)\ \mathrm{d}\lambda(\theta)

    with :math:`P^\theta(x)=\langle x,\theta\rangle` and :math:`\lambda` the
    uniform distribution on the unit sphere.

    .. warning:: This function only works in pytorch or jax as it uses
        autodifferentiation to compute the 1D UOT problems. It is not
        maintained in jax.

    Parameters
    ----------
    X_s : ndarray, shape (n_samples_a, dim)
        samples in the source domain
    X_t : ndarray, shape (n_samples_b, dim)
        samples in the target domain
    reg_m : float or indexable object of length 1 or 2
        Marginal relaxation term. If `reg_m` is a scalar or an indexable
        object of length 1, then the same `reg_m` is applied to both marginal
        relaxations. The balanced OT can be recovered using
        `reg_m=float("inf")`. For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If `reg_m` is an array, it must have the same backend as input
        arrays `(X_s, X_t)`.
    a : ndarray, shape (n_samples_a,), optional
        samples weights in the source domain
    b : ndarray, shape (n_samples_b,), optional
        samples weights in the target domain
    n_projections : int, optional
        Number of projections used for the Monte-Carlo approximation
    p : float, optional, by default =2
        Power p used for computing the sliced Wasserstein
    projections : shape (dim, n_projections), optional
        Projection matrix (n_projections and seed are not used in this case)
    seed : int or RandomState or None, optional
        Seed used for random number generator
    numItermax : int, optional
        Number of Frank-Wolfe iterations used by the 1D UOT solver
    log : bool, optional
        if True, returns the projections used and their associated UOTs and
        reweighted marginals.

    Returns
    -------
    loss : float/array-like, shape (...)
        SUOT
    log : dict, optional
        If `log` is True, then returns a dictionary containing the projection
        directions used, the projected UOTs, and reweighted marginals on each
        slices.

    .. _references-suot:
    References
    ----------
    .. [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N.
        (2025). Slicing Unbalanced Optimal Transport. Transactions on Machine
        Learning Research.

    See Also
    --------
    ot.unbalanced.uot_1d : 1D OT problem
    ot.unbalanced.unbalanced_sliced_ot : Unbalanced SOT problem
    """
    nx = get_backend(X_s, X_t, a, b, projections)
    assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax"

    n_s, n_t = X_s.shape[0], X_t.shape[0]
    dim = X_s.shape[1]

    # Both point clouds must live in the same ambient space.
    if dim != X_t.shape[1]:
        raise ValueError(
            "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
                X_s.shape[1], X_t.shape[1]
            )
        )

    # Default to uniform weights when none are provided.
    if a is None:
        a = nx.full(n_s, 1 / n_s, type_as=X_s)
    if b is None:
        b = nx.full(n_t, 1 / n_t, type_as=X_s)

    if projections is None:
        projections = get_random_projections(
            dim, n_projections, seed, backend=nx, type_as=X_s
        )
    else:
        n_projections = projections.shape[1]

    # Project every sample on every direction in one matrix product.
    proj_X_s = nx.dot(X_s, projections)  # shape (n, n_projs)
    proj_X_t = nx.dot(X_t, projections)

    # Solve the 1D UOT problem on every slice at once.
    a_reweighted, b_reweighted, projected_uot = uot_1d(
        proj_X_s,
        proj_X_t,
        reg_m,
        a,
        b,
        p,
        require_sort=True,
        numItermax=numItermax,
        returnCost="total",
    )

    # Monte-Carlo average over the projection directions.
    loss = nx.mean(projected_uot)

    if not log:
        return loss

    return loss, {
        "projections": projections,
        "projected_uots": projected_uot,
        "a_reweighted": a_reweighted,
        "b_reweighted": b_reweighted,
    }
def get_reweighted_marginals_usot(
    f, g, a, b, reg_m1, reg_m2, X_s_sorter, X_t_sorter, nx
):
    r"""One step of the FW algorithm for the Unbalanced Sliced OT problem,
    see Algorithm 1 and 3 in :ref:`[82] <references-uot>`.

    This function computes the reweighted marginals given the current
    potentials and the translation term. It returns the current potentials,
    and the reweighted marginals (normalized by the mass, so that they sum
    to 1).

    Parameters
    ----------
    f : array-like shape (n, ...)
        Current potential on the source samples
    g : array-like shape (m, ...)
        Current potential on the target samples
    a : array-like shape (n, ...)
        Current weights on the source samples
    b : array-like shape (m, ...)
        Current weights on the target samples
    reg_m1 : float
        Marginal relaxation term for the source distribution
    reg_m2 : float
        Marginal relaxation term for the target distribution
    X_s_sorter : array-like shape (n_projs, n)
        Sorter for the projected source samples
    X_t_sorter : array-like shape (n_projs, m)
        Sorter for the projected target samples
    nx : module
        backend module

    Returns
    -------
    f : array-like shape (n, ...)
        Current potential on the source samples
    g : array-like shape (m, ...)
        Current potential on the target samples
    a_reweighted : array-like shape (n, ...)
        Reweighted weights on the source samples (normalized by the mass)
    b_reweighted : array-like shape (m, ...)
        Reweighted weights on the target samples (normalized by the mass)
    full_mass : array-like shape (...)
        Mass of the reweighted measures

    .. _references-uot:
    References
    ----------
    [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N.
        (2025). Slicing Unbalanced Optimal Transport. Transactions on Machine
        Learning Research.
    """
    # Translate the pair of potentials by the constant given by the solver.
    shift = rescale_potentials(f, g, a, b, reg_m1, reg_m2, nx)
    f, g = f + shift, g - shift

    # Reweight each marginal with its exponentiated potential (KL case) and
    # gather it in sorted order; an infinite relaxation leaves it untouched.
    if reg_m1 == float("inf"):
        a_reweighted = a[..., X_s_sorter]
    else:
        a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter]

    if reg_m2 == float("inf"):
        b_reweighted = b[..., X_t_sorter]
    else:
        b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter]

    # Mass of the reweighted source measure, kept before normalization.
    full_mass = nx.sum(a_reweighted, axis=1)

    # Normalize the weights for compatibility with wasserstein_1d.
    mass_a = nx.sum(a_reweighted, axis=1, keepdims=True)
    mass_b = nx.sum(b_reweighted, axis=1, keepdims=True)
    return f, g, a_reweighted / mass_a, b_reweighted / mass_b, full_mass
def unbalanced_sliced_ot(
    X_s,
    X_t,
    reg_m,
    a=None,
    b=None,
    n_projections=50,
    p=2,
    projections=None,
    seed=None,
    numItermax=10,
    log=False,
):
    r"""Compute the Unbalanced Sliced Optimal Transport (USOT) with KL
    regularization between two empirical distributions.

    The Unbalanced SOT problem reads as

    .. math::
        \mathrm{USOT}_p^p(\mu, \nu) = \inf_{\pi_1,\pi_2} \mathrm{SW}_p^p(\pi_1, \pi_2) + \mathrm{reg_{m}}_1 \mathrm{KL}(\pi_1||\mu) + \mathrm{reg_{m}}_2 \mathrm{KL}(\pi_2||\nu).

    The USOT problem is solved with a Frank-Wolfe algorithm as proposed in
    :ref:`[82] <references-usot>`.

    .. warning:: This function only works in pytorch or jax as it uses
        autodifferentiation to compute the 1D potentials. It is not
        maintained in jax.

    Parameters
    ----------
    X_s : ndarray, shape (n_samples_a, dim)
        samples in the source domain
    X_t : ndarray, shape (n_samples_b, dim)
        samples in the target domain
    reg_m : float or indexable object of length 1 or 2
        Marginal relaxation term. If `reg_m` is a scalar or an indexable
        object of length 1, then the same `reg_m` is applied to both marginal
        relaxations. The balanced OT can be recovered using
        `reg_m=float("inf")`. For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If `reg_m` is an array, it must have the same backend as input
        arrays `(X_s, X_t)`.
    a : ndarray, shape (n_samples_a,), optional
        samples weights in the source domain
    b : ndarray, shape (n_samples_b,), optional
        samples weights in the target domain
    n_projections : int, optional
        Number of projections used for the Monte-Carlo approximation
    p : float, optional, by default =2
        Power p used for computing the sliced Wasserstein
    projections : shape (dim, n_projections), optional
        Projection matrix (n_projections and seed are not used in this case)
    seed : int or RandomState or None, optional
        Seed used for random number generator
    numItermax : int, optional
        Number of Frank-Wolfe iterations
    log : bool, optional
        if True, returns the sot loss, the projections used, their associated
        EMD and the full mass of the reweighted marginals.

    Returns
    -------
    a_reweighted : array-like shape (n, ...)
        First marginal reweighted
    b_reweighted : array-like shape (m, ...)
        Second marginal reweighted
    loss : float/array-like, shape (...)
        USOT
    log : dict, optional
        If `log` is True, then returns a dictionary containing the projection
        directions used, the 1D OT losses, the SOT loss and the full mass of
        reweighted marginals.

    .. _references-usot:
    References
    ----------
    .. [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N.
        (2025). Slicing Unbalanced Optimal Transport. Transactions on Machine
        Learning Research.

    See Also
    --------
    ot.unbalanced.uot_1d : 1D OT problem
    ot.unbalanced.sliced_unbalanced_ot : SUOT problem
    """
    nx = get_backend(X_s, X_t, a, b, projections)
    assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax"

    # Split the relaxation term into one coefficient per marginal.
    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    n = X_s.shape[0]
    m = X_t.shape[0]

    if X_s.shape[1] != X_t.shape[1]:
        raise ValueError(
            "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
                X_s.shape[1], X_t.shape[1]
            )
        )

    # Default to uniform weights when none are provided.
    if a is None:
        a = nx.full(n, 1 / n, type_as=X_s)
    if b is None:
        b = nx.full(m, 1 / m, type_as=X_s)

    d = X_s.shape[1]

    if projections is None:
        projections = get_random_projections(
            d, n_projections, seed, backend=nx, type_as=X_s
        )
    else:
        n_projections = projections.shape[1]

    # Compute projections of the samples, and sort them for later use in the
    # FW algorithm (sorting is done once here so the per-iteration 1D solves
    # can run with require_sort=False).
    X_s_projections = nx.dot(X_s, projections).T  # shape (n_projs, n)
    X_t_projections = nx.dot(X_t, projections).T

    # sorter maps sorted order -> original order; rev_sorter inverts it so
    # dual potentials computed on sorted samples can be scattered back.
    X_s_sorter = nx.argsort(X_s_projections, -1)
    X_s_rev_sorter = nx.argsort(X_s_sorter, -1)
    X_s_sorted = nx.take_along_axis(X_s_projections, X_s_sorter, -1)

    X_t_sorter = nx.argsort(X_t_projections, -1)
    X_t_rev_sorter = nx.argsort(X_t_sorter, -1)
    X_t_sorted = nx.take_along_axis(X_t_projections, X_t_sorter, -1)

    # Initialize potentials - WARNING: They correspond to non-sorted samples
    f = nx.zeros(a.shape, type_as=a)
    g = nx.zeros(b.shape, type_as=b)

    for i in range(numItermax):
        # Reweight the marginals with the current potentials (per slice).
        f, g, a_reweighted, b_reweighted, _ = get_reweighted_marginals_usot(
            f, g, a, b, reg_m1, reg_m2, X_s_sorter, X_t_sorter, nx
        )

        # Dual potentials of the 1D OT problems between reweighted marginals.
        fd, gd, _ = emd_1d_dual_backprop(
            X_s_sorted.T,
            X_t_sorted.T,
            u_weights=a_reweighted.T,
            v_weights=b_reweighted.T,
            p=p,
            require_sort=False,
        )
        fd, gd = fd.T, gd.T

        # default step for FW
        t = 2.0 / (2.0 + i)

        # FW update: move towards the slice-averaged dual potentials,
        # mapped back to the original (unsorted) sample order.
        f = f + t * (nx.mean(nx.take_along_axis(fd, X_s_rev_sorter, 1), axis=0) - f)
        g = g + t * (nx.mean(nx.take_along_axis(gd, X_t_rev_sorter, 1), axis=0) - g)

    # Final reweighting with the converged potentials; keep the full mass
    # to undo the per-slice normalization in the loss below.
    f, g, a_reweighted, b_reweighted, full_mass = get_reweighted_marginals_usot(
        f,
        g,
        a,
        b,
        reg_m1,
        reg_m2,
        X_s_sorter,
        X_t_sorter,
        nx,
    )

    # 1D OT losses between the normalized reweighted marginals on each slice.
    ot_loss = wasserstein_1d(
        X_s_sorted.T,
        X_t_sorted.T,
        u_weights=a_reweighted.T,
        v_weights=b_reweighted.T,
        p=p,
        require_sort=False,
    )

    # Rescale each slice loss by its mass before averaging over directions.
    sot_loss = nx.mean(ot_loss * full_mass)

    # Recompute the (unsorted, unnormalized) reweighted marginals to return.
    if reg_m1 != float("inf"):
        a_reweighted = a * nx.exp(-f / reg_m1)
    else:
        a_reweighted = a

    if reg_m2 != float("inf"):
        b_reweighted = b * nx.exp(-g / reg_m2)
    else:
        b_reweighted = b

    # Add the KL marginal penalties; an infinite coefficient drops its term.
    if reg_m1 == float("inf") and reg_m2 == float("inf"):
        uot_loss = sot_loss
    elif reg_m1 == float("inf"):
        uot_loss = sot_loss + reg_m2 * nx.kl_div(b_reweighted, b, mass=True)
    elif reg_m2 == float("inf"):
        uot_loss = sot_loss + reg_m1 * nx.kl_div(a_reweighted, a, mass=True)
    else:
        uot_loss = (
            sot_loss
            + reg_m1 * nx.kl_div(a_reweighted, a, mass=True)
            + reg_m2 * nx.kl_div(b_reweighted, b, mass=True)
        )

    if log:
        dico = {
            "projections": projections,
            "sot_loss": sot_loss,
            "1d_losses": ot_loss,
            "full_mass": full_mass,
        }
        return a_reweighted, b_reweighted, uot_loss, dico

    return a_reweighted, b_reweighted, uot_loss