# Source code for ot.lowrank

"""
Low rank OT solvers
"""

# Author: Laurène David <laurene.david@ip-paris.fr>
#
# License: MIT License

import warnings
from .utils import unif, dist, get_lowrank_lazytensor
from .backend import get_backend
from .bregman import sinkhorn

# test if sklearn is installed for linux-minimal-deps
try:
    import sklearn.cluster

    sklearn_import = True
except ImportError:
    sklearn_import = False


def _init_lr_sinkhorn(X_s, X_t, a, b, rank, init, reg_init, random_state, nx=None):
    """
    Implementation of different initialization strategies for the low rank sinkhorn solver (Q ,R, g).
    This function is specific to lowrank_sinkhorn.

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain
    rank : int
        Nonnegative rank of the OT plan.
    init : str
        Initialization strategy for Q, R and g. 'random', 'deterministic' or 'kmeans'
    reg_init : float, optional.
        Regularization term for a 'kmeans' init.
    random_state : int, optional.
        Random state for a 'random' or 'kmeans' init strategy
    nx : optional, Default is None
        POT backend


    Returns
    ---------
    Q : array-like, shape (n_samples_a, r)
        Init for the first low-rank matrix decomposition of the OT plan (Q)
    R: array-like, shape (n_samples_b, r)
        Init for the second low-rank matrix decomposition of the OT plan (R)
    g : array-like, shape (r, )
        Init for the weight vector of the low-rank decomposition of the OT plan (g)

    Raises
    ------
    NotImplementedError
        If ``init`` is not one of 'random', 'deterministic' or 'kmeans'.
    ImportError
        If ``init='kmeans'`` and scikit-learn is not installed.


    References
    -----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.

    """

    # Reject unknown strategies up front: otherwise no branch below assigns
    # Q, R, g and the function fails later with an UnboundLocalError.
    if init not in ("random", "deterministic", "kmeans"):
        raise NotImplementedError('Not implemented init="{}"'.format(init))

    if nx is None:
        nx = get_backend(X_s, X_t, a, b)

    ns = X_s.shape[0]
    nt = X_t.shape[0]
    r = rank

    if init == "random":
        nx.seed(seed=random_state)

        # Init g: strictly positive random weights normalized to sum to 1
        g = nx.abs(nx.randn(r, type_as=X_s)) + 1
        g = g / nx.sum(g)

        # Init Q: strictly positive random entries, each row i rescaled to sum to a[i]
        Q = nx.abs(nx.randn(ns, r, type_as=X_s)) + 1
        Q = (Q.T * (a / nx.sum(Q, axis=1))).T

        # Init R: same construction, each row j rescaled to sum to b[j]
        R = nx.abs(nx.randn(nt, r, type_as=X_s)) + 1
        R = (R.T * (b / nx.sum(R, axis=1))).T

    elif init == "deterministic":
        # Init g: uniform weights. type_as keeps g on the samples' dtype/device
        # so it can be mixed with the other backend arrays below.
        g = nx.ones(r, type_as=X_s) / r

        # Mix each marginal (a, b, g) with an increasing normalized profile
        # (a1, b1, g1). The weight lambda_1 is half the smallest marginal entry,
        # so the complementary parts a2, b2, g2 stay positive.
        lambda_1 = min(nx.min(a), nx.min(g), nx.min(b)) / 2
        a1 = nx.arange(start=1, stop=ns + 1, type_as=X_s)
        a1 = a1 / nx.sum(a1)
        a2 = (a - lambda_1 * a1) / (1 - lambda_1)

        b1 = nx.arange(start=1, stop=nt + 1, type_as=X_s)
        b1 = b1 / nx.sum(b1)
        b2 = (b - lambda_1 * b1) / (1 - lambda_1)

        g1 = nx.arange(start=1, stop=r + 1, type_as=X_s)
        g1 = g1 / nx.sum(g1)
        g2 = (g - lambda_1 * g1) / (1 - lambda_1)

        # Init Q as a sum of two rank-one outer products (rows sum to a, columns to g)
        Q1 = lambda_1 * nx.dot(a1[:, None], nx.reshape(g1, (1, -1)))
        Q2 = (1 - lambda_1) * nx.dot(a2[:, None], nx.reshape(g2, (1, -1)))
        Q = Q1 + Q2

        # Init R the same way (rows sum to b, columns to g)
        R1 = lambda_1 * nx.dot(b1[:, None], nx.reshape(g1, (1, -1)))
        R2 = (1 - lambda_1) * nx.dot(b2[:, None], nx.reshape(g2, (1, -1)))
        R = R1 + R2

    else:  # init == "kmeans"
        if not sklearn_import:
            raise ImportError(
                "Scikit-learn should be installed to use the 'kmeans' init."
            )

        # Init g: uniform weights
        g = nx.ones(r, type_as=X_s) / r

        def _kmeans_coupling(X, w):
            # Entropic OT plan between samples X (weights w) and their `r` KMeans centroids (weights g).
            kmeans = sklearn.cluster.KMeans(
                n_clusters=r, random_state=random_state, n_init="auto"
            )
            # sklearn works on numpy arrays. Convert explicitly so that
            # non-numpy backends (e.g. GPU tensors) are supported.
            kmeans.fit(nx.to_numpy(X))
            Z = nx.from_numpy(kmeans.cluster_centers_, type_as=X)
            C = dist(X, Z)  # shape (n_samples, rank)
            C = C / nx.max(C)
            return sinkhorn(w, g, C, reg=reg_init, numItermax=10000, stopThr=1e-3)

        # Init Q and R
        Q = _kmeans_coupling(X_s, a)
        R = _kmeans_coupling(X_t, b)

    return Q, R, g


def compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost, nx=None):
    """
    Compute the low rank decomposition of a squared euclidean distance matrix.
    This function won't work for other distance metrics.

    The factors are built so that ``M1 @ M2.T`` has entry (i, j) equal to
    ``||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>``, i.e. the squared euclidean cost.
    When ``rescale_cost`` is truthy, both factors are divided by a positive
    constant, so the product is a rescaled version of that cost.
    See "Section 3.5, proposition 1"

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    rescale_cost : bool
        Rescale the low rank factorization of the sqeuclidean cost matrix
    nx : default None
        POT backend

    Returns
    ----------
    M1 : array-like, shape (n_samples_a, dim+2)
        First low rank decomposition of the distance matrix
    M2 : array-like, shape (n_samples_b, dim+2)
        Second low rank decomposition of the distance matrix


    References
    -----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-rank Sinkhorn factorization". In International Conference on Machine Learning.
    """
    if nx is None:
        nx = get_backend(X_s, X_t)

    ns = X_s.shape[0]
    nt = X_t.shape[0]

    # First low rank factor (A): rows are [||x_i||^2, 1, -2 x_i]
    sq_norms_s = nx.reshape(nx.sum(X_s**2, 1), (-1, 1))
    ones_s = nx.ones((ns, 1), type_as=X_s)
    M1 = nx.concatenate((sq_norms_s, ones_s, -2 * X_s), axis=1)

    # Second low rank factor (B): rows are [1, ||y_j||^2, y_j]
    ones_t = nx.ones((nt, 1), type_as=X_s)
    sq_norms_t = nx.reshape(nx.sum(X_t**2, 1), (-1, 1))
    M2 = nx.concatenate((ones_t, sq_norms_t, X_t), axis=1)

    # Truthiness instead of `is True`, so values such as numpy.bool_(True)
    # or 1 also enable rescaling instead of being silently ignored.
    if rescale_cost:
        # Each factor contains a column of ones, so its max is >= 1 and the
        # square root / division are well defined.
        M1 = M1 / nx.sqrt(nx.max(M1))
        M2 = M2 / nx.sqrt(nx.max(M2))

    return M1, M2
def _LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
    """
    Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver.
    This function is specific to lowrank_sinkhorn.

    Parameters
    ----------
    eps1 : array-like, shape (n_samples_a, r)
        First input parameter of the Dykstra algorithm
    eps2 : array-like, shape (n_samples_b, r)
        Second input parameter of the Dykstra algorithm
    eps3 : array-like, shape (r,)
        Third input parameter of the Dykstra algorithm
    p1 : array-like, shape (n_samples_a,)
        Samples weights in the source domain (same as "a" in lowrank_sinkhorn)
    p2 : array-like, shape (n_samples_b,)
        Samples weights in the target domain (same as "b" in lowrank_sinkhorn)
    alpha: float
        Lower bound for the weight vector g (same as "alpha" in lowrank_sinkhorn)
    stopThr : float
        Stop threshold on error (L1 violation of the two marginal constraints)
    numItermax : int
        Max number of iterations (at least 1)
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    nx : default None
        POT backend


    Returns
    ----------
    Q : array-like, shape (n_samples_a, r)
        Dykstra update of the first low-rank matrix decomposition Q
    R: array-like, shape (n_samples_b, r)
        Dykstra update of the Second low-rank matrix decomposition R
    g : array-like, shape (r, )
        Dykstra update of the weight vector g

    Raises
    ------
    ValueError
        If ``numItermax`` is smaller than 1 (no iterate would be computed).


    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.

    """
    # At least one iteration is required: Q, R and g are built from its iterates.
    if numItermax < 1:
        raise ValueError(
            "numItermax ({}) should be at least 1 for the Dykstra algorithm.".format(
                numItermax
            )
        )

    # POT backend if None
    if nx is None:
        nx = get_backend(eps1, eps2, eps3, p1, p2)

    # ----------------- Initialisation of Dykstra algorithm -----------------
    r = len(eps3)  # rank
    g_ = nx.copy(eps3)  # \tilde{g}
    q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(3)}_1, q^{(3)}_2
    v1_, v2_ = (
        nx.ones(r, type_as=p1),
        nx.ones(r, type_as=p1),
    )  # \tilde{v}^{(1)}, \tilde{v}^{(2)}
    q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(1)}, q^{(2)}

    # --------------------- Dykstra algorithm -------------------------
    # See Section 3.3 - "Algorithm 2 LR-Dykstra" in paper
    for _ in range(numItermax):
        # Compute u^{(1)} and u^{(2)}
        u1 = p1 / nx.dot(eps1, v1_)
        u2 = p2 / nx.dot(eps2, v2_)

        # Compute g (clipped from below by alpha), g^{(3)}_1 and update \tilde{g}
        g = nx.maximum(alpha, g_ * q3_1)
        q3_1 = (g_ * q3_1) / g
        g_ = nx.copy(g)

        # eps^T u is used both in the product and in v below: compute it once
        eps1T_u1 = nx.dot(eps1.T, u1)
        eps2T_u2 = nx.dot(eps2.T, u2)

        # Compute new value of g with \prod (geometric mean of the three terms)
        prod1 = (v1_ * q1) * eps1T_u1
        prod2 = (v2_ * q2) * eps2T_u2
        g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)

        # Compute v^{(1)} and v^{(2)}
        v1 = g / eps1T_u1
        v2 = g / eps2T_u2

        # Compute q^{(1)}, q^{(2)} and q^{(3)}_2
        q1 = (v1_ * q1) / v1
        q2 = (v2_ * q2) / v2
        q3_2 = (g_ * q3_2) / g

        # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g}
        v1_, v2_ = nx.copy(v1), nx.copy(v2)
        g_ = nx.copy(g)

        # Compute error: L1 violation of the marginal constraints
        err1 = nx.sum(nx.abs(u1 * nx.dot(eps1, v1) - p1))
        err2 = nx.sum(nx.abs(u2 * nx.dot(eps2, v2) - p2))
        err = err1 + err2

        # Checked after the update, so at least one iterate always exists and
        # convergence on the last allowed iteration does not trigger the warning.
        # Written as "not err > stopThr" so that a NaN error also stops the loop.
        if not err > stopThr:
            break
    else:
        if warn:
            warnings.warn(
                "Dykstra did not converge. You might want to "
                "increase the number of iterations `numItermax` "
            )

    # Compute low rank matrices Q, R
    Q = u1[:, None] * eps1 * v1[None, :]
    R = u2[:, None] * eps2 * v2[None, :]

    return Q, R, g
def lowrank_sinkhorn(
    X_s,
    X_t,
    a=None,
    b=None,
    reg=0,
    rank=None,
    alpha=1e-10,
    rescale_cost=True,
    init="random",
    reg_init=1e-1,
    seed_init=49,
    gamma_init="rescale",
    numItermax=2000,
    stopThr=1e-7,
    warn=True,
    log=False,
):
    r"""
    Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the couplings.

    The function solves the following optimization problem:

    .. math::
        \mathop{\inf_{(\mathbf{Q},\mathbf{R},\mathbf{g}) \in \mathcal{C}(\mathbf{a},\mathbf{b},r)}} \langle \mathbf{C}, \mathbf{Q}\mathrm{diag}(1/\mathbf{g})\mathbf{R}^\top \rangle -
            \mathrm{reg} \cdot H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))

    where :

    - :math:`\mathbf{C}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`H((\mathbf{Q}, \mathbf{R}, \mathbf{g}))` is the values of the three respective entropies evaluated for each term.
    - :math:`\mathbf{Q}` and :math:`\mathbf{R}` are the low-rank matrix decomposition of the OT plan
    - :math:`\mathbf{g}` is the weight vector for the low-rank decomposition of the OT plan
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
    - :math:`r` is the rank of the OT plan
    - :math:`\mathcal{C}(\mathbf{a}, \mathbf{b}, r)` are the low-rank couplings of the OT problem


    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain
    reg : float, optional
        Regularization term >0
    rank : int, optional. Default is None. (>0)
        Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
        Values larger than min(ns, nt) are clipped to min(ns, nt).
    alpha : float, optional. Default is 1e-10. (>0 and <1/r)
        Lower bound for the weight vector g.
    rescale_cost : bool, optional. Default is True
        Rescale the low rank factorization of the sqeuclidean cost matrix
    init : str, optional. Default is 'random'.
        Initialization strategy for the low rank couplings. 'random', 'deterministic' or 'kmeans'
    reg_init : float, optional. Default is 1e-1. (>0)
        Regularization term for a 'kmeans' init. If None, 1 is considered.
    seed_init : int, optional. Default is 49. (>0)
        Random state for a 'random' or 'kmeans' init strategy.
    gamma_init : str, optional. Default is "rescale".
        Initialization strategy for gamma. 'rescale', or 'theory'
        Gamma is a constant that scales the convergence criterion of the Mirror Descent
        optimization scheme used to compute the low-rank couplings (Q, R and g)
    numItermax : int, optional. Default is 2000.
        Max number of iterations for the Dykstra algorithm
    stopThr : float, optional. Default is 1e-7.
        Stop threshold on error (>0) in Dykstra
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    log : bool, optional
        record log if True


    Returns
    ---------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R: array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r, )
        Weight vector for the low-rank decomposition of the OT plan
    log : dict (lazy_plan, value and value_linear)
        log dictionary return only if log==True in parameters.
        value and value_linear are computed with the (possibly rescaled,
        see ``rescale_cost``) low rank factorization of the cost.

    Raises
    ------
    ValueError
        If the effective rank is not positive, or if alpha is larger than 1/rank.
    NotImplementedError
        If ``gamma_init`` is neither 'rescale' nor 'theory'.


    References
    ----------
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-rank Sinkhorn Factorization". In International Conference on Machine Learning.

    """

    # POT backend
    nx = get_backend(X_s, X_t)
    ns, nt = X_s.shape[0], X_t.shape[0]

    # Initialize weights a, b
    if a is None:
        a = unif(ns, type_as=X_s)
    if b is None:
        b = unif(nt, type_as=X_t)

    # Compute rank (see Section 3.1, def 1): it can never exceed min(ns, nt)
    r = min(ns, nt) if rank is None else min(ns, nt, rank)

    if r <= 0:
        raise ValueError("The rank parameter cannot have a negative value")

    # Dykstra won't converge if 1/rank < alpha (see Section 3.2).
    # Use the effective rank r: `rank` may be None here.
    if 1 / r < alpha:
        raise ValueError(
            "alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
                a=alpha, r=1 / r
            )
        )

    # Validate gamma_init before any expensive work (cost factorization, init)
    if gamma_init not in ["rescale", "theory"]:
        raise NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init))

    # Documented fallback for the 'kmeans' init regularization
    if reg_init is None:
        reg_init = 1

    # Low rank decomposition of the sqeuclidean cost matrix
    M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, rescale_cost, nx)

    # Initialize the low rank matrices Q, R, g
    Q, R, g = _init_lr_sinkhorn(X_s, X_t, a, b, r, init, reg_init, seed_init, nx=nx)

    # Gamma initialization ('theory': fixed step size computed once)
    if gamma_init == "theory":
        L = nx.sqrt(
            3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2)
            + (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
        )
        gamma = 1 / (2 * L)

    # -------------------------- Low rank algorithm ------------------------------
    # see "Section 3.3, Algorithm 3 LOT"

    for _ in range(100):
        # Compute C*R dot using the lr decomposition of C
        CR = nx.dot(M2.T, R)
        CR_ = nx.dot(M1, CR)
        diag_g = (1 / g)[None, :]
        CR_g = CR_ * diag_g

        # Compute C.T * Q using the lr decomposition of C
        CQ = nx.dot(M1.T, Q)
        CQ_ = nx.dot(M2, CQ)
        CQ_g = CQ_ * diag_g

        # Compute omega
        omega = nx.diag(nx.dot(Q.T, CR_))

        # Rescale gamma at each iteration
        if gamma_init == "rescale":
            norm_1 = nx.max(nx.abs(CR_ * diag_g + reg * nx.log(Q))) ** 2
            norm_2 = nx.max(nx.abs(CQ_ * diag_g + reg * nx.log(R))) ** 2
            norm_3 = nx.max(nx.abs(-omega * diag_g)) ** 2
            gamma = 10 / max(norm_1, norm_2, norm_3)

        eps1 = nx.exp(-gamma * CR_g - ((gamma * reg) - 1) * nx.log(Q))
        eps2 = nx.exp(-gamma * CQ_g - ((gamma * reg) - 1) * nx.log(R))
        eps3 = nx.exp((gamma * omega / (g**2)) - (gamma * reg - 1) * nx.log(g))

        # LR Dykstra algorithm
        Q, R, g = _LR_Dysktra(
            eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx
        )
        # Keep entries strictly positive for the log terms of the next iteration
        Q = Q + 1e-16
        R = R + 1e-16
        g = g + 1e-16

    # ----------------- Compute lazy_plan, value and value_linear ------------------
    # see "Section 3.2: The Low-rank OT Problem" in the paper

    # Compute lazy plan (using LazyTensor class)
    lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)

    # Compute value_linear = <C, Q diag(1/g) R^T> (using trace formula).
    # 1/g is recomputed from the final g: the in-loop diag_g predates the last
    # Dykstra update and would be inconsistent with lazy_plan.
    diag_g = (1 / g)[None, :]
    v1 = nx.dot(Q.T, M1)
    v2 = nx.dot(R, (v1.T * diag_g).T)
    value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2)))

    # Compute value with entropy reg (see "Section 3.2" in the paper)
    reg_Q = nx.sum(Q * nx.log(Q + 1e-16))  # entropy for Q
    reg_g = nx.sum(g * nx.log(g + 1e-16))  # entropy for g
    reg_R = nx.sum(R * nx.log(R + 1e-16))  # entropy for R
    value = value_linear + reg * (reg_Q + reg_g + reg_R)

    if log:
        dict_log = dict()
        dict_log["value"] = value
        dict_log["value_linear"] = value_linear
        dict_log["lazy_plan"] = lazy_plan

        return Q, R, g, dict_log

    return Q, R, g