Source code for ot.sliced._sliced_plans

# -*- coding: utf-8 -*-
"""
Sliced Wasserstein plans approximation solvers: min-pivot sliced (=sliced
Wasserstein Generalized Geodesic) and expected sliced.
"""

# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
#         Laetitia Chapel <laetitia.chapel@irisa.fr>
#
# License: MIT License

import warnings

from ..backend import get_backend
from ..utils import list_to_array, sparse_ot_dist, dist
from ._utils import get_random_projections
from ..lp import wasserstein_1d
from collections import namedtuple


def sliced_plans(
    X_s,
    X_t,
    a=None,
    b=None,
    metric="sqeuclidean",
    p=1,
    projections=None,
    n_projections=None,
    seed=None,
    batch_size=None,
    log=False,
):
    r"""
    Compute one 1D transport plan per projection direction between `X_s` and
    `X_t`, together with the cost of each plan evaluated on the original
    (non-projected) points.

    Both datasets are projected on every column of `projections`. When the two
    datasets have the same number of points and uniform weights of equal mass,
    the plan for direction `k` is encoded by the two sorting permutations of
    the projections: the point `X_s[rows[i], :]` is matched to
    `X_t[cols[i], :]`. Otherwise the plans are obtained from
    :func:`wasserstein_1d` with `return_plans="coo_tuple"`.

    Parameters
    ----------
    X_s : array-like, shape (ns, d)
        The first set of vectors.
    X_t : array-like, shape (nt, d)
        The second set of vectors.
    a : ndarray of float64, shape (ns,), optional
        Source histogram (default is uniform weight)
    b : ndarray of float64, shape (nt,), optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only works with either of the strings
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
    p: float, optional (default=1)
        The p-norm to apply for if metric='minkowski'. It is overridden to 2
        for `'euclidean'` and to 1 for `'cityblock'`.
    projections: shape (d, n_projections), optional
        Projection matrix, one direction per column (n_projections and seed
        are not used in this case)
    n_projections : int, optional
        The number of projection directions. Required if projections is None.
    seed : int, optional
        The seed for the random number generator for sampling projections, in
        case projections is None. Default is None.
    batch_size : int, optional
        Forwarded to `sparse_ot_dist` when computing the costs.
        Default is None.
    log : bool, optional
        If True, returns additional logging information. Default is False.

    Returns
    -------
    plans : list of tuples
        One entry per projection, exposing the attributes `data`, `rows` and
        `cols` of the non-zero elements of the transportation matrix.
    costs : array-like
        The per-projection costs, stacked with the backend.
    log_dict : dict, optional
        Dictionary with the keys `"X_projection"`, `"Y_projection"` and
        `"projections"`. Returned only if `log` is True.

    Raises
    ------
    AssertionError
        If `X_s` or `X_t` is not 2d, if their numbers of columns differ, if
        `metric` is not supported, or if both `projections` and
        `n_projections` are None.
    """
    nx = get_backend(X_s, X_t, a, b, projections)

    assert X_s.ndim == 2, f"X_s must be a 2d array, got {X_s.ndim}d array instead"
    assert X_t.ndim == 2, f"X_t must be a 2d array, got {X_t.ndim}d array instead"
    assert metric in ("minkowski", "euclidean", "cityblock", "sqeuclidean"), (
        "Sliced plans work only with metrics "
        + "from the following list: "
        + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
    )
    assert (
        X_s.shape[1] == X_t.shape[1]
    ), f"X_s ({X_s.shape}) and X_t ({X_t.shape}) must have the same number of columns"

    if metric == "euclidean":
        p = 2
    elif metric == "cityblock":
        p = 1

    d = X_s.shape[1]
    n = X_s.shape[0]
    m = X_t.shape[0]

    # Default weights follow the dtype/device of the data (type_as) instead of
    # the backend defaults.
    default_weights = a is None and b is None
    if a is None:
        a = nx.ones((n,), type_as=X_s) / n
    if b is None:
        b = nx.ones((m,), type_as=X_s) / m

    if n != m:
        is_perm = False
    elif default_weights:
        # Both histograms are uniform by construction: no need for an exact
        # floating point comparison, which can fail by rounding of the sum.
        is_perm = True
    else:
        # Both histograms must be constant and equal to the mean mass of `a`,
        # which also enforces that `a` and `b` carry the same total mass.
        mean_mass = a.sum() / n
        is_perm = (a == mean_mass).all() and (b == mean_mass).all()

    if projections is None:
        # sample the projection directions, one per column: (d, n_projections)
        assert (
            n_projections is not None
        ), "n_projections must be specified if projections is None"
        projections = get_random_projections(
            d, n_projections, seed, backend=nx, type_as=X_s
        )
    else:
        n_projections = projections.shape[1]

    # projections on the lines: column k holds the projections on direction k
    X_s_projections = nx.dot(X_s, projections)
    X_t_projections = nx.dot(X_t, projections)

    if is_perm:
        # we compute maps (permutations)
        # sigma[:, k] is a permutation sorting X_s_projections[:, k]
        sigma = nx.argsort(X_s_projections, axis=0)  # (n, n_projections)
        tau = nx.argsort(X_t_projections, axis=0)  # (m, n_projections)

        costs = [
            sparse_ot_dist(
                X_s,
                X_t,
                sigma[:, k],
                tau[:, k],
                metric=metric,
                p=p,
                batch_size=batch_size,
            )
            for k in range(n_projections)
        ]

        PlanTuple = namedtuple("PlanTuple", ["data", "rows", "cols"])
        plans = [
            PlanTuple(
                data=a,
                rows=sigma[:, k],
                cols=tau[:, k],
            )
            for k in range(n_projections)
        ]
    else:
        # we compute plans
        _, plans = wasserstein_1d(
            X_s_projections,
            X_t_projections,
            a,
            b,
            p,
            require_sort=True,
            return_plans="coo_tuple",
        )
        costs = [
            sparse_ot_dist(
                X_s,
                X_t,
                plans[k].rows,
                plans[k].cols,
                plans[k].data,
                metric=metric,
                p=p,
                batch_size=batch_size,
            )
            for k in range(n_projections)
        ]

    if log:
        log_dict = {
            "X_projection": X_s_projections,
            "Y_projection": X_t_projections,
            "projections": projections,
        }
        return plans, nx.stack(costs), log_dict
    else:
        return plans, nx.stack(costs)
def min_sliced_transport_plan(
    X_s,
    X_t,
    a=None,
    b=None,
    projections=None,
    metric="sqeuclidean",
    p=2,
    n_projections=None,
    seed=None,
    batch_size=None,
    dense=True,
    log=False,
):
    r"""
    Computes the cost and plan associated to the min-Pivot Sliced Discrepancy
    (introduced as min-SWGG in [85] and studied further in [86]).

    The measures supported on `X_s` (`ns` atoms) and `X_t` (`nt` atoms) in
    dimension `d` are projected on `n_projections` directions. Each direction
    yields a coupling through :func:`sliced_plans`, and the coupling with the
    lowest cost between `X_s` and `X_t` (compared in :math:`\mathbb{R}^d`) is
    returned. When `ns=nt`, it gives

    .. math::
        \mathrm{min\text{-}PS}_p^p(X_s, X_t) \approx
        \min_{k \in [1, n_{\mathrm{proj}}]} \left(
        \frac{1}{n_s} \sum_{i=1}^{n_s} \|X_{s,i} -
        X_{t,\sigma_k(i)}\|_2^p \right),

    where :math:`\sigma_k` is a permutation such that ordering the projections
    on the axis `projections[:, k]` matches :math:`X_s[i, :]` to
    :math:`X_t[\sigma_k(i), :]`.

    .. note::
        The computation ignores potential ambiguities in the projections: if
        two points from a same measure have the same projection on a direction,
        then multiple sorting permutations are possible. To avoid combinatorial
        explosion, only one permutation is retained: this strays from theory in
        pathological cases.

    .. warning::
        Tensorflow and jax only returns dense plans, as they do not support
        well sparse matrices.

    Parameters
    ----------
    X_s : array-like, shape (ns, d)
        The first set of vectors.
    X_t : array-like, shape (nt, d)
        The second set of vectors.
    a : ndarray of float64, shape (ns,), optional
        Source histogram (default is uniform weight)
    b : ndarray of float64, shape (nt,), optional
        Target histogram (default is uniform weight)
    projections: shape (d, n_projections), optional
        Projection matrix, one direction per column (n_projections and seed
        are not used in this case). Default is None
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only works with either of the strings
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
    p: float, optional (default=2)
        The p-norm to apply for if metric='minkowski'
    n_projections : int, optional
        The number of projection directions. Required if projections is None.
    seed : int, optional
        The seed for the random number generator for sampling projections, in
        case projections is None. Default is None.
    batch_size : int, optional
        If specified, compute the distance in batches of size `batch_size` to
        avoid memory issues for large datasets. Default is None (no batching).
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format.
    log : bool, optional
        If True, returns additional logging information. Default is False.

    Returns
    -------
    plan : ndarray, shape (ns, nt) or coo_matrix if dense is False
        Transportation matrix of the projection with the lowest cost.
    cost : float
        The cost associated to that projection.
    log_dict : dict, optional
        Dictionary with the keys `"projections"`, `"costs"`,
        `"min_projection"`, `"X_min_projection"` and `"Y_min_projection"`.
        Returned only if `log` is True.

    References
    ----------
    .. [85] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023).
            Fast Optimal Transport through Sliced Generalized Wasserstein
            Geodesics. Advances in Neural Information Processing Systems, 36,
            35350–35385.
    .. [86] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Transport Plans.
            arXiv preprint 2506.03661.

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> x=np.array([[3.,3.], [1.,1.]])
    >>> y=np.array([[2.,2.5], [3.,2.]])
    >>> projections=np.array([[1, 0], [0, 1]])
    >>> plan, cost = ot.min_sliced_transport_plan(x, y, projections=projections)
    >>> plan
    array([[0. , 0.5],
           [0.5, 0. ]])
    >>> cost
    2.125
    """
    nx = get_backend(X_s, X_t, a, b, projections)

    assert X_s.ndim == 2, f"X_s must be a 2d array, got {X_s.ndim}d array instead"
    assert X_t.ndim == 2, f"X_t must be a 2d array, got {X_t.ndim}d array instead"
    assert (
        X_s.shape[1] == X_t.shape[1]
    ), f"X_s ({X_s.shape}) and X_t ({X_t.shape}) must have the same number of columns"

    # sparse outputs are unavailable on these backends: fall back to dense
    if not dense and str(nx) in ("tf", "jax"):
        warnings.warn("JAX and TF do not support sparse matrices, converting to dense")
        dense = True

    # one candidate plan and one cost per projection direction
    all_plans, all_costs, slices_log = sliced_plans(
        X_s,
        X_t,
        a=a,
        b=b,
        metric=metric,
        p=p,
        projections=projections,
        n_projections=n_projections,
        seed=seed,
        batch_size=batch_size,
        log=True,
    )

    # keep the direction whose plan is the cheapest on the original points
    best = nx.argmin(all_costs)
    cost = all_costs[best]
    best_plan = all_plans[best]

    # assemble the plan from the indices of its non-zero entries
    n_source, n_target = X_s.shape[0], X_t.shape[0]
    sparse_plan = nx.coo_matrix(
        best_plan.data,
        best_plan.rows,
        best_plan.cols,
        shape=(n_source, n_target),
        type_as=X_s,
    )
    plan = nx.todense(sparse_plan) if dense else sparse_plan

    if not log:
        return plan, cost

    directions = slices_log["projections"]
    log_dict = {
        "projections": directions,
        "costs": all_costs,
        "min_projection": directions[:, best],
        "X_min_projection": slices_log["X_projection"][:, best],
        "Y_min_projection": slices_log["Y_projection"][:, best],
    }
    return plan, cost, log_dict
def expected_sliced_plan(
    X_s,
    X_t,
    a=None,
    b=None,
    projections=None,
    metric="sqeuclidean",
    p=2,
    n_projections=None,
    beta=0.0,
    seed=None,
    dense=True,
    batch_size=None,
    log=False,
):
    r"""
    Computes the Expected Sliced cost and plan between two datasets `X_s` and
    `X_t` of shapes `(ns, d)` and `(nt, d)`.

    Given a set of `n_projections` projection directions, the expected sliced
    plan is a weighted average of the `n_projections` 1d optimal transport
    plans between the projections of `X_s` and `X_t` on each direction. The
    weights are uniform when `beta` is 0, and are otherwise the softmax of
    `-beta * costs`, where `costs` holds the cost of each per-direction plan.
    Expected Sliced was introduced in [87] and further studied in [86].

    .. note::
        The computation ignores potential ambiguities in the projections: if
        two points from a same measure have the same projection on a direction,
        then multiple sorting permutations are possible. To avoid combinatorial
        explosion, only one permutation is retained: this strays from theory in
        pathological cases.

    .. warning::
        Tensorflow and jax only returns dense plans, as they do not support
        well sparse matrices.

    Parameters
    ----------
    X_s : array-like, shape (ns, d)
        The first set of vectors.
    X_t : array-like, shape (nt, d)
        The second set of vectors.
    a : ndarray of float64, shape (ns,), optional
        Source histogram (default is uniform weight)
    b : ndarray of float64, shape (nt,), optional
        Target histogram (default is uniform weight)
    projections: shape (d, n_projections), optional
        Projection matrix, one direction per column (n_projections and seed
        are not used in this case). Default is None
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only works with either of the strings
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
    p: float, optional (default=2)
        The p-norm to apply for if metric='minkowski'
    n_projections : int, optional
        The number of projection directions. Required if projections is None.
    seed : int, optional
        The seed for the random number generator for sampling projections, in
        case projections is None. Default is None.
    beta : float, optional
        Inverse-temperature parameter which weights each projection's
        contribution to the expected plan. Default is 0 (uniform weighting).
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format.
    batch_size : int, optional
        Forwarded to :func:`sliced_plans` for the per-direction costs.
        Default is None (no batching).
    log : bool, optional
        If True, returns additional logging information. Default is False.

    Returns
    -------
    plan : ndarray, shape (ns, nt) or coo_matrix if dense is False
        Weighted average of the per-direction transport plans.
    cost : float
        The transport cost of the averaged plan.
    log_dict : dict, optional
        Dictionary with the keys `"projections"`, `"costs"`, `"G"` (the
        per-direction plans) and `"weights"`. Returned only if `log` is True.

    References
    ----------
    .. [86] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Transport Plans.
            arXiv preprint 2506.03661.
    .. [87] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M.,
            Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans.
            International Conference on Learning Representations.

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> x=np.array([[3.,3.], [1.,1.]])
    >>> y=np.array([[2.,2.5], [3.,2.]])
    >>> projections=np.array([[1, 0], [0, 1]])
    >>> plan, cost = ot.expected_sliced_plan(x, y, projections=projections)
    >>> plan
    array([[0.25, 0.25],
           [0.25, 0.25]])
    >>> cost
    2.625
    """
    nx = get_backend(X_s, X_t, a, b, projections)

    assert X_s.ndim == 2, f"X_s must be a 2d array, got {X_s.ndim}d array instead"
    assert X_t.ndim == 2, f"X_t must be a 2d array, got {X_t.ndim}d array instead"
    assert (
        X_s.shape[1] == X_t.shape[1]
    ), f"X_s ({X_s.shape}) and X_t ({X_t.shape}) must have the same number of columns"

    if str(nx) in ["tf", "jax"] and not dense:
        dense = True
        warnings.warn("JAX and TF do not support sparse matrices, converting to dense")

    n = X_s.shape[0]
    m = X_t.shape[0]

    plans, costs, log_dict_plans = sliced_plans(
        X_s,
        X_t,
        a,
        b,
        metric,
        p,
        projections,
        n_projections=n_projections,
        seed=seed,
        log=True,
        batch_size=batch_size,
    )

    uniform_weighting = beta == 0.0
    if uniform_weighting:
        # Size the weights from the plans actually computed: a user-supplied
        # `n_projections` is ignored when `projections` is given and may not
        # match the number of directions.
        n_plans = len(plans)
        weights = nx.ones((n_plans,), type_as=costs) / n_plans
    else:
        # temperature weighting: softmax of -beta * costs
        log_factors = -beta * list_to_array(costs)
        weights = nx.exp(log_factors - nx.logsumexp(log_factors))
        cost = nx.sum(list_to_array(costs) * weights)

    # Stack the non-zero entries of all the weighted plans; entries sharing the
    # same (row, col) are left as duplicates for the sparse constructor.
    weights_e = nx.concatenate(
        [plan_k.data * weights[k] for k, plan_k in enumerate(plans)]
    )
    Xs_idx = nx.concatenate([plan_k.rows for plan_k in plans])
    Xt_idx = nx.concatenate([plan_k.cols for plan_k in plans])
    plan = nx.coo_matrix(weights_e, Xs_idx, Xt_idx, shape=(n, m), type_as=weights)

    if dense:
        plan = nx.todense(plan)

    if uniform_weighting:
        # NOTE(review): this builds the full (ns, nt) ground cost matrix and
        # does not use `batch_size` — confirm this is intended.
        M = dist(X_s, X_t, metric=metric, p=p)
        if dense:
            cost = nx.sum(plan * M)
        else:
            cost = plan.multiply(M).sum()

    if log:
        log_dict = {
            "projections": log_dict_plans["projections"],
            "costs": costs,
            "G": plans,
            "weights": weights,
        }
        return plan, cost, log_dict
    else:
        return plan, cost