Low rank Gromov-Wasserstein solver
# Author: Laurène David <laurene.david@ip-paris.fr>
# License: MIT License
import warnings
from ..utils import unif, get_lowrank_lazytensor
from ..backend import get_backend
from ..lowrank import compute_lr_sqeuclidean_matrix, _init_lr_sinkhorn, _LR_Dysktra
def _flat_product_operator(X, nx=None):
    r"""
    Implementation of the flattened out-product operator.

    This function is used in low rank gromov wasserstein to compute the low rank
    decomposition of a cost matrix's squared hadamard product (page 6 in paper).
    Row ``i`` of the output is the row-major flattening of the outer product
    ``X[i] X[i]^T``. Consequently, if ``C = A1 @ A2.T`` then
    ``C * C == _flat_product_operator(A1) @ _flat_product_operator(A2).T``.

    Parameters
    ----------
    X: array-like, shape (n_samples, n_col)
        Input matrix for operator
    nx: default None
        POT backend. If None, it is inferred from ``X`` with ``get_backend``.

    Returns
    -------
    X_flat: array-like, shape (n_samples, n_col**2)
        Matrix with flattened out-product operator applied on each row.
        An empty input (n_samples == 0) yields an empty (0, n_col**2) output.

    References
    ----------
    .. [67] Scetbon, M., Peyré, G. & Cuturi, M. (2022).
        "Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs".
        In International Conference on Machine Learning (ICML), 2022.
    """
    if nx is None:
        nx = get_backend(X)

    n_samples, n_col = X.shape[0], X.shape[1]

    # Batched outer products: outer[i, j, k] = X[i, j] * X[i, k].
    # A single broadcasted product replaces the former per-row loop, which
    # grew the result with repeated concatenation (re-copying it at every
    # step, i.e. quadratic in n_samples) and failed on an empty input.
    outer = X[:, :, None] * X[:, None, :]

    # Row-major reshape flattens each (n_col, n_col) block in the same order
    # as the former ``nx.dot(x, x.T).flatten()`` did for a column vector x.
    return nx.reshape(outer, (n_samples, n_col * n_col))
def lowrank_gromov_wasserstein_samples(
Solve the entropic regularization Gromov-Wasserstein transport problem under low-nonnegative rank constraints
on the couplings and cost matrices.
Squared euclidean distance matrices are considered for the target and source distributions.
The function solves the following optimization problem:
.. math::
\mathop{\min_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \mathcal{Q}_{A,B}(Q\mathrm{diag}(1/g)R^T) -
\epsilon \cdot H((Q,R,g))
where :
- :math:`A` is the (`n_samples_a`, `n_samples_a`) square pairwise cost matrix of the source domain.
- :math:`B` is the (`n_samples_b`, `n_samples_b`) square pairwise cost matrix of the target domain.
- :math:`\mathcal{Q}_{A,B}` is the quadratic objective function of the Gromov-Wasserstein plan.
- :math:`Q` and :math:`R` are the low-rank matrix decomposition of the Gromov-Wasserstein plan.
- :math:`g` is the weight vector for the low-rank decomposition of the Gromov-Wasserstein plan.
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1).
- :math:`r` is the rank of the Gromov-Wasserstein plan.
- :math:`\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem.
- :math:`H((Q,R,g))` are the values of the three respective entropies evaluated for each term.
- :math:`\epsilon` is the regularization term ``reg``.
X_s : array-like, shape (n_samples_a, dim_Xs)
Samples in the source domain
X_t : array-like, shape (n_samples_b, dim_Xt)
Samples in the target domain
a : array-like, shape (n_samples_a,), optional
Samples weights in the source domain
If left to its default value None, uniform distribution is taken.
b : array-like, shape (n_samples_b,), optional
Samples weights in the target domain
If left to its default value None, uniform distribution is taken.
reg : float, optional
Regularization term >=0 (the :math:`\epsilon` above)
rank : int, optional. Default is None. (>0)
Nonnegative rank of the OT plan. If None, min(ns, nt) is considered.
Values larger than min(ns, nt) are capped at min(ns, nt).
alpha : float, optional. Default is 1e-10. (>0 and <1/r)
Lower bound for the weight vector g.
rescale_cost : bool, optional. Default is False
Rescale the low rank factorization of the sqeuclidean cost matrix
seed_init : int, optional. Default is 49. (>0)
Random state for the 'random' initialization of low rank couplings
gamma_init : str, optional. Default is "rescale".
Initialization strategy for gamma. 'rescale', or 'theory'
Gamma is a constant that scales the convergence criterion of the Mirror Descent
optimization scheme used to compute the low-rank couplings (Q, R and g)
numItermax : int, optional. Default is 1000.
Max number of iterations for Low Rank GW
stopThr : float, optional. Default is 1e-4.
Stop threshold on error (>0) for Low Rank GW
The error is the sum of symmetrized Kullback-Leibler divergences computed for each low rank
coupling (Q, R and g) and scaled by (1 / gamma) ** 2.
numItermax_dykstra : int, optional. Default is 2000.
Max number of iterations for the Dykstra algorithm
stopThr_dykstra : float, optional. Default is 1e-7.
Stop threshold on error (>0) in Dykstra
cost_factorized_Xs: tuple, optional. Default is None
Tuple with two pre-computed low rank decompositions (A1, A2) of the source cost
matrix. Both matrices should have a shape of (n_samples_a, dim_Xs + 2).
If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
cost_factorized_Xt: tuple, optional. Default is None
Tuple with two pre-computed low rank decompositions (B1, B2) of the target cost
matrix. Both matrices should have a shape of (n_samples_b, dim_Xt + 2).
If None, the low rank cost matrices will be computed as sqeuclidean cost matrices.
warn : bool, optional
if True, raises a warning if the low rank GW algorithm doesn't converge.
warn_dykstra: bool, optional
if True, raises a warning if the Dykstra algorithm doesn't converge.
log : bool, optional
record log if True
Q : array-like, shape (n_samples_a, r)
First low-rank matrix decomposition of the OT plan
R: array-like, shape (n_samples_b, r)
Second low-rank matrix decomposition of the OT plan
g : array-like, shape (r, )
Weight vector for the low-rank decomposition of the OT plan
log : dict (lazy_plan, value and value_quad)
log dictionary returned only if log==True in parameters
# POT backend
nx = get_backend(X_s, X_t)
ns, nt = X_s.shape[0], X_t.shape[0]
# Initialize weights a, b
if a is None:
a = unif(ns, type_as=X_s)
if b is None:
b = unif(nt, type_as=X_t)
# Compute rank (see Section 3.1, def 1)
# Effective rank r: min(ns, nt) when rank is None, otherwise rank capped at min(ns, nt).
# NOTE(review): the `else:` that should guard `r = min(ns, nt, rank)` is not visible in this
# view; run unconditionally with rank=None, min() would raise TypeError - confirm.
r = rank
if rank is None:
r = min(ns, nt)
r = min(ns, nt, rank)
# NOTE(review): this check also rejects r == 0, although the message only mentions negative values.
if r <= 0:
raise ValueError("The rank parameter cannot have a negative value")
# Dykstra won't converge if 1/rank < alpha (see Section 3.2)
# NOTE(review): the message formats `1 / rank` rather than `1 / r`. With rank=None this
# raises TypeError instead of the intended ValueError, and when rank > min(ns, nt) it
# reports the wrong bound. `r=1 / r` looks intended.
if 1 / r < alpha:
raise ValueError(
"alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
a=alpha, r=1 / rank
# Low-rank factors of the source (A1, A2) and target (B1, B2) costs: the user-supplied
# factorizations, or sqeuclidean factors built from the samples.
# NOTE(review): the `else:` lines pairing these alternatives are not visible in this view.
if cost_factorized_Xs is not None:
A1, A2 = cost_factorized_Xs
A1, A2 = compute_lr_sqeuclidean_matrix(X_s, X_s, rescale_cost, nx=nx)
if cost_factorized_Xt is not None:
B1, B2 = cost_factorized_Xt
B1, B2 = compute_lr_sqeuclidean_matrix(X_t, X_t, rescale_cost, nx=nx)
# Initial values for LR couplings (Q, R, g) with LOT
# (random initialization seeded by seed_init)
Q, R, g = _init_lr_sinkhorn(
X_s, X_t, a, b, r, init="random", random_state=seed_init, reg_init=None, nx=nx
# Gamma initialization
# 'theory': fixed step gamma = 1 / (2L) with L = 27 * ||A1|| * ||A2|| / alpha**4.
# 'rescale': gamma is recomputed at every iteration inside the loop below.
# NOTE(review): L only uses the source factors A1, A2; confirm against the paper whether the
# target factors B1, B2 should also enter.
if gamma_init == "theory":
L = (27 * nx.norm(A1) * nx.norm(A2)) / alpha**4
gamma = 1 / (2 * L)
# Any other gamma_init value is rejected.
if gamma_init not in ["rescale", "theory"]:
raise (
NotImplementedError('Not implemented gamma_init="{}"'.format(gamma_init))
# initial value of error
# (1 is above stopThr for any stopThr < 1, so the first iteration runs)
err = 1
for ii in range(numItermax):
# Keep the previous couplings for the KL-based stopping criterion.
Q_prev = Q
R_prev = R
g_prev = g
if err > stopThr:
# Compute cost matrices
# Low-rank factorization C = C1 @ C2 of the mirror-descent cost, with
# C1 = -4 * A1 @ (A2^T @ Q diag(1/g)) (ns x r) and C2 = R^T @ B1 @ B2^T (r x nt).
C1 = nx.dot(A2.T, Q * (1 / g)[None, :])
C1 = -4 * nx.dot(A1, C1)
C2 = nx.dot(R.T, B1)
C2 = nx.dot(C2, B2.T)
diag_g = (1 / g)[None, :]
# Compute C*R dot using the lr decomposition of C
# C @ R = C1 @ (C2 @ R): the ns x nt matrix is never formed; CR_g scales columns by 1/g.
CR = nx.dot(C2, R)
CR = nx.dot(C1, CR)
CR_g = CR * diag_g
# Compute C.T * Q using the lr decomposition of C
# C^T @ Q = C2^T @ (C1^T @ Q); CQ_g scales columns by 1/g.
CQ = nx.dot(C1.T, Q)
CQ = nx.dot(C2.T, CQ)
CQ_g = CQ * diag_g
# Compute omega
# omega_k = (Q^T C R)_kk, used in the gradient with respect to g.
omega = nx.diag(nx.dot(Q.T, CR))
# Rescale gamma at each iteration
# gamma = 10 / (largest squared sup-norm among the three gradient terms).
if gamma_init == "rescale":
norm_1 = nx.max(nx.abs(CR_g + reg * nx.log(Q))) ** 2
norm_2 = nx.max(nx.abs(CQ_g + reg * nx.log(R))) ** 2
norm_3 = nx.max(nx.abs(-omega * (diag_g**2))) ** 2
gamma = 10 / max(norm_1, norm_2, norm_3)
# Mirror-descent kernels (exponentiated gradient steps) for Q, R and g.
K1 = nx.exp(-gamma * CR_g - ((gamma * reg) - 1) * nx.log(Q))
K2 = nx.exp(-gamma * CQ_g - ((gamma * reg) - 1) * nx.log(R))
K3 = nx.exp((gamma * omega / (g**2)) - (gamma * reg - 1) * nx.log(g))
# Update couplings with LR Dykstra algorithm
# NOTE(review): the arguments of this call are not visible in this view.
Q, R, g = _LR_Dysktra(
# Update error with kullback-divergence
# Symmetrized KL between consecutive iterates, scaled by (1 / gamma) ** 2.
err_1 = ((1 / gamma) ** 2) * (nx.kl_div(Q, Q_prev) + nx.kl_div(Q_prev, Q))
err_2 = ((1 / gamma) ** 2) * (nx.kl_div(R, R_prev) + nx.kl_div(R_prev, R))
err_3 = ((1 / gamma) ** 2) * (nx.kl_div(g, g_prev) + nx.kl_div(g_prev, g))
err = err_1 + err_2 + err_3
# fix divide by zero
# (keeps the next iteration's log(Q), log(R), log(g) and 1/g finite)
Q = Q + 1e-16
R = R + 1e-16
g = g + 1e-16
# NOTE(review): the `warnings.warn(` call (and surrounding loop `else:`) that these message
# fragments belong to is not visible in this view.
if warn:
"Low Rank GW did not converge. You might want to "
"increase the number of iterations `numItermax` "
# Update low rank costs
# (same factorization as in the loop, evaluated with the final Q, R, g)
C1 = nx.dot(A2.T, Q * (1 / g)[None, :])
C1 = -4 * nx.dot(A1, C1)
C2 = nx.dot(R.T, B1)
C2 = nx.dot(C2, B2.T)
# Compute lazy plan (using LazyTensor class)
lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)
# Compute value_quad
# Constant term a^T (A*A) a + b^T (B*B) b: the flattened out-product factors give a low
# rank factorization of the elementwise squares, so A*A and B*B are never formed.
A1_, A2_ = _flat_product_operator(A1, nx), _flat_product_operator(A2, nx)
B1_, B2_ = _flat_product_operator(B1, nx), _flat_product_operator(B2, nx)
x_ = nx.dot(A1_, nx.dot(A2_.T, a))
y_ = nx.dot(B1_, nx.dot(B2_.T, b))
c1 = nx.dot(x_, a) + nx.dot(y_, b)
# NOTE(review): diag_g here is the value computed inside the loop from g *before* the last
# Dykstra update and the +1e-16 shift, so it does not match the final g used in C1 above.
# It is also undefined if numItermax == 0. Recomputing `(1 / g)[None, :]` here looks intended.
G = nx.dot(C1, nx.dot(C2, R))
G = nx.dot(Q.T, G * diag_g)
value_quad = c1 + nx.trace(G) / 2
# Entropic part: value = value_quad + reg * (sum Q log Q + sum g log g + sum R log R).
if reg != 0:
reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q
reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g
reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R
value = value_quad + reg * (reg_Q + reg_g + reg_R)
value = value_quad
if log:
dict_log = dict()
dict_log["value"] = value
dict_log["value_quad"] = value_quad
dict_log["lazy_plan"] = lazy_plan
return Q, R, g, dict_log
return Q, R, g