Source code for ot.gromov._lowrank

"""
Low rank Gromov-Wasserstein solver
"""

# Author: Laurène David <laurene.david@ip-paris.fr>
#
# License: MIT License

import warnings
from ..utils import unif, get_lowrank_lazytensor
from ..backend import get_backend
from ..lowrank import compute_lr_sqeuclidean_matrix, _init_lr_sinkhorn, _LR_Dysktra


def _flat_product_operator(X, nx=None):
    r"""
    Implementation of the flattened out-product operator.

    This function is used in low rank gromov wasserstein to compute the low rank decomposition of
    a cost matrix's squared hadamard product (page 6 in paper).

    Row ``i`` of the output is the row-major flattening of the outer product of ``X[i, :]``
    with itself, i.e. ``X_flat[i, j * n_col + k] = X[i, j] * X[i, k]``.

    Parameters
    ----------
    X: array-like, shape (n_samples, n_col)
        Input matrix for operator

    nx: default None
        POT backend

    Returns
    ----------
    X_flat: array-like, shape (n_samples, n_col**2)
        Matrix with flattened out-product operator applied on each row.
        An input with zero rows yields an output of shape (0, n_col**2).

    References
    ----------
    .. [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
        "Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs".
        In International Conference on Machine Learning (ICML), 2022.

    """

    if nx is None:
        nx = get_backend(X)

    n_samples, n_col = X.shape[0], X.shape[1]

    # Batched outer products via broadcasting: outer[i, j, k] = X[i, j] * X[i, k],
    # shape (n_samples, n_col, n_col). This replaces a per-row concatenation loop
    # that re-copied the growing result on every iteration (quadratic in n_samples)
    # and failed on inputs with zero rows.
    outer = X[:, :, None] * X[:, None, :]

    # Row-major reshape flattens each (n_col, n_col) outer product into one row.
    return nx.reshape(outer, (n_samples, n_col * n_col))


def lowrank_gromov_wasserstein_samples(
    X_s,
    X_t,
    a=None,
    b=None,
    reg=0,
    rank=None,
    alpha=1e-10,
    gamma_init="rescale",
    rescale_cost=True,
    cost_factorized_Xs=None,
    cost_factorized_Xt=None,
    stopThr=1e-4,
    numItermax=1000,
    stopThr_dykstra=1e-3,
    numItermax_dykstra=10000,
    seed_init=49,
    warn=True,
    warn_dykstra=False,
    log=False,
):
    r"""
    Solve the entropic regularization Gromov-Wasserstein transport problem under low-nonnegative rank constraints
    on the couplings and cost matrices.

    Squared euclidean distance matrices are considered for the target and source distributions.

    The function solves the following optimization problem:

    .. math::
        \mathop{\min_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \mathcal{Q}_{A,B}(Q\mathrm{diag}(1/g)R^T) -
        \epsilon \cdot H((Q,R,g))

    where :

    - :math:`A` is the (`dim_a`, `dim_a`) square pairwise cost matrix of the source domain.
    - :math:`B` is the (`dim_b`, `dim_b`) square pairwise cost matrix of the target domain.
    - :math:`\mathcal{Q}_{A,B}` is quadratic objective function of the Gromov Wasserstein plan.
    - :math:`Q` and `R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
    - :math:`g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1).
    - :math:`r` is the rank of the Gromov-Wasserstein plan.
    - :math:`\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem.
    - :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term.

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim_Xs)
        Samples in the source domain
    X_t : array-like, shape (n_samples_b, dim_Xt)
        Samples in the target domain
    a : array-like, shape (n_samples_a,), optional
        Samples weights in the source domain
        If let to its default value None, uniform distribution is taken.
    b : array-like, shape (n_samples_b,), optional
        Samples weights in the target domain
        If let to its default value None, uniform distribution is taken.
    reg : float, optional
        Regularization term >=0
    rank : int, optional. Default is None. (>0)
        Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
        Values larger than min(ns, nt) are clipped to min(ns, nt).
    alpha : float, optional. Default is 1e-10. (>0 and <1/r)
        Lower bound for the weight vector g.
    gamma_init : str, optional. Default is "rescale".
        Initialization strategy for gamma. 'rescale', or 'theory'
        Gamma is a constant that scales the convergence criterion of the Mirror Descent
        optimization scheme used to compute the low-rank couplings (Q, R and g)
    rescale_cost : bool, optional. Default is True
        Rescale the low rank factorization of the sqeuclidean cost matrix
    cost_factorized_Xs: tuple, optional. Default is None
        Tuple with two pre-computed low rank decompositions (A1, A2) of the source cost
        matrix. Both matrices should have a shape of (n_samples_a, dim_Xs + 2).
        If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
    cost_factorized_Xt: tuple, optional. Default is None
        Tuple with two pre-computed low rank decompositions (B1, B2) of the target cost
        matrix. Both matrices should have a shape of (n_samples_b, dim_Xt + 2).
        If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
    stopThr : float, optional. Default is 1e-4.
        Stop threshold on error (>0) for Low Rank GW
        The error is the sum of Kullback Divergences computed for each low rank
        coupling (Q, R and g) and scaled using gamma.
    numItermax : int, optional. Default is 1000.
        Max number of iterations for Low Rank GW
    stopThr_dykstra : float, optional. Default is 1e-3.
        Stop threshold on error (>0) in Dykstra
    numItermax_dykstra : int, optional. Default is 10000.
        Max number of iterations for the Dykstra algorithm
    seed_init : int, optional. Default is 49. (>0)
        Random state for the 'random' initialization of low rank couplings
    warn : bool, optional
        if True, raises a warning if the low rank GW algorithm doesn't converge.
    warn_dykstra: bool, optional
        if True, raises a warning if the Dykstra algorithm doesn't converge.
    log : bool, optional
        record log if True

    Returns
    ---------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R: array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r, )
        Weight vector for the low-rank decomposition of the OT
    log : dict (lazy_plan, value and value_quad)
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If the effective rank is not strictly positive, or if alpha >= 1/rank.
    NotImplementedError
        If gamma_init is neither "rescale" nor "theory".

    References
    ----------
    .. [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
        "Linear-Time GromovWasserstein Distances using Low Rank Couplings and Costs".
        In International Conference on Machine Learning (ICML), 2022.

    """

    # POT backend
    nx = get_backend(X_s, X_t)
    ns, nt = X_s.shape[0], X_t.shape[0]

    # Initialize weights a, b (uniform by default)
    if a is None:
        a = unif(ns, type_as=X_s)
    if b is None:
        b = unif(nt, type_as=X_t)

    # Compute rank (see Section 3.1, def 1): it can never exceed min(ns, nt)
    r = min(ns, nt) if rank is None else min(ns, nt, rank)

    if r <= 0:
        raise ValueError("The rank parameter cannot have a negative value")

    # Dykstra won't converge if 1/rank < alpha (see Section 3.2).
    # Report 1/r (the effective rank): `rank` may be None or larger than min(ns, nt).
    if 1 / r < alpha:
        raise ValueError(
            "alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
                a=alpha, r=1 / r
            )
        )

    # Validate before doing any expensive factorization / initialization work
    if gamma_init not in ["rescale", "theory"]:
        raise NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init))

    # Low rank factorizations of the source (A = A1 A2^T) and target (B = B1 B2^T) costs
    if cost_factorized_Xs is not None:
        A1, A2 = cost_factorized_Xs
    else:
        A1, A2 = compute_lr_sqeuclidean_matrix(X_s, X_s, rescale_cost, nx=nx)

    if cost_factorized_Xt is not None:
        B1, B2 = cost_factorized_Xt
    else:
        B1, B2 = compute_lr_sqeuclidean_matrix(X_t, X_t, rescale_cost, nx=nx)

    # Initial values for LR couplings (Q, R, g) with LOT
    Q, R, g = _init_lr_sinkhorn(
        X_s, X_t, a, b, r, init="random", random_state=seed_init, reg_init=None, nx=nx
    )

    # Gamma initialization ("rescale" recomputes gamma at every iteration below)
    if gamma_init == "theory":
        L = (27 * nx.norm(A1) * nx.norm(A2)) / alpha**4
        gamma = 1 / (2 * L)

    # initial value of error
    err = 1

    for ii in range(numItermax):
        if not err > stopThr:
            break

        Q_prev, R_prev, g_prev = Q, R, g

        # diag(1/g) applied by broadcasting a (1, r) row vector
        diag_g = (1 / g)[None, :]

        # Low rank factorization C = C1 @ C2 of the GW gradient cost
        C1 = -4 * nx.dot(A1, nx.dot(A2.T, Q * diag_g))
        C2 = nx.dot(nx.dot(R.T, B1), B2.T)

        # Compute C R using the lr decomposition of C
        CR = nx.dot(C1, nx.dot(C2, R))
        CR_g = CR * diag_g

        # Compute C^T Q using the lr decomposition of C
        CQ = nx.dot(C2.T, nx.dot(C1.T, Q))
        CQ_g = CQ * diag_g

        # Compute omega
        omega = nx.diag(nx.dot(Q.T, CR))

        # Rescale gamma at each iteration
        if gamma_init == "rescale":
            norm_1 = nx.max(nx.abs(CR_g + reg * nx.log(Q))) ** 2
            norm_2 = nx.max(nx.abs(CQ_g + reg * nx.log(R))) ** 2
            norm_3 = nx.max(nx.abs(-omega * (diag_g**2))) ** 2
            gamma = 10 / max(norm_1, norm_2, norm_3)

        # Mirror descent kernels for Q, R and g
        K1 = nx.exp(-gamma * CR_g - ((gamma * reg) - 1) * nx.log(Q))
        K2 = nx.exp(-gamma * CQ_g - ((gamma * reg) - 1) * nx.log(R))
        K3 = nx.exp((gamma * omega / (g**2)) - (gamma * reg - 1) * nx.log(g))

        # Update couplings with LR Dykstra algorithm
        Q, R, g = _LR_Dysktra(
            K1,
            K2,
            K3,
            a,
            b,
            alpha,
            stopThr_dykstra,
            numItermax_dykstra,
            warn_dykstra,
            nx,
        )

        # Update error with symmetrized kullback-divergence, scaled by 1/gamma^2
        scale = (1 / gamma) ** 2
        err_1 = scale * (nx.kl_div(Q, Q_prev) + nx.kl_div(Q_prev, Q))
        err_2 = scale * (nx.kl_div(R, R_prev) + nx.kl_div(R_prev, R))
        err_3 = scale * (nx.kl_div(g, g_prev) + nx.kl_div(g_prev, g))
        err = err_1 + err_2 + err_3

        # fix divide by zero / log(0) in the next iteration
        Q = Q + 1e-16
        R = R + 1e-16
        g = g + 1e-16

    else:
        # Loop exhausted without an early stop; only warn if the last
        # update still did not reach the stopping threshold.
        if warn and err > stopThr:
            warnings.warn(
                "Low Rank GW did not converge. You might want to "
                "increase the number of iterations `numItermax` "
            )

    # Update low rank costs with the FINAL g (a stale diag_g from inside the
    # loop would mix two different g's in the objective below)
    diag_g = (1 / g)[None, :]
    C1 = -4 * nx.dot(A1, nx.dot(A2.T, Q * diag_g))
    C2 = nx.dot(nx.dot(R.T, B1), B2.T)

    # Compute lazy plan (using LazyTensor class)
    lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)

    # Compute value_quad: constant term from the squared hadamard products of A and B
    A1_, A2_ = _flat_product_operator(A1, nx), _flat_product_operator(A2, nx)
    B1_, B2_ = _flat_product_operator(B1, nx), _flat_product_operator(B2, nx)

    x_ = nx.dot(A1_, nx.dot(A2_.T, a))
    y_ = nx.dot(B1_, nx.dot(B2_.T, b))
    c1 = nx.dot(x_, a) + nx.dot(y_, b)

    # trace(Q^T C R diag(1/g)) = <C, P> with P = Q diag(1/g) R^T
    G = nx.dot(C1, nx.dot(C2, R))
    G = nx.dot(Q.T, G * diag_g)
    value_quad = c1 + nx.trace(G) / 2

    if reg != 0:
        reg_Q = nx.sum(Q * nx.log(Q + 1e-16))  # entropy for Q
        reg_g = nx.sum(g * nx.log(g + 1e-16))  # entropy for g
        reg_R = nx.sum(R * nx.log(R + 1e-16))  # entropy for R
        value = value_quad + reg * (reg_Q + reg_g + reg_R)
    else:
        value = value_quad

    if log:
        dict_log = dict()
        dict_log["value"] = value
        dict_log["value_quad"] = value_quad
        dict_log["lazy_plan"] = lazy_plan

        return Q, R, g, dict_log

    return Q, R, g