# -*- coding: utf-8 -*-
# Source code for ot.gromov._utils

"""
Gromov-Wasserstein and Fused-Gromov-Wasserstein utils.
"""

# Author: Erwan Vautier <erwan.vautier@gmail.com>
#         Nicolas Courty <ncourty@irisa.fr>
#         Rémi Flamary <remi.flamary@unice.fr>
#         Titouan Vayer <titouan.vayer@irisa.fr>
#         Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
#

from ..utils import list_to_array
from ..backend import get_backend

def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None):
    r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation

    Returns the constant matrix and the decomposition factors needed to evaluate
    :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
    selected loss function for the Gromov-Wasserstein discrepancy.

    The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`

    Where :

    - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
    - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
    - :math:`\mathbf{T}`: A coupling between those two spaces

    The square-loss function :math:`L(a, b) = |a - b|^2` is read as :

    .. math::

        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)

        \mathrm{with} \ f_1(a) &= a^2

        f_2(b) &= b^2

        h_1(a) &= a

        h_2(b) &= 2b

    The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :

    .. math::

        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)

        \mathrm{with} \ f_1(a) &= a \log(a) - a

        f_2(b) &= b

        h_1(a) &= a

        h_2(b) &= \log(b)

    Parameters
    ----------
    C1 : array-like, shape (ns, ns)
        Metric cost matrix in the source space
    C2 : array-like, shape (nt, nt)
        Metric cost matrix in the target space
    p : array-like, shape (ns,)
        Probability distribution in the source space
    q : array-like, shape (nt,)
        Probability distribution in the target space
    loss_fun : str, optional
        Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    constC : array-like, shape (ns, nt)
        Constant :math:`\mathbf{C}` matrix in Eq. (6)
    hC1 : array-like, shape (ns, ns)
        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
    hC2 : array-like, shape (nt, nt)
        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)


    .. _references-init-matrix:
    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    """
    if nx is None:
        C1, C2, p, q = list_to_array(C1, C2, p, q)
        nx = get_backend(C1, C2, p, q)

    if loss_fun == 'square_loss':
        def f1(a):
            return a**2

        def f2(b):
            return b**2

        def h1(a):
            return a

        def h2(b):
            return 2 * b
    elif loss_fun == 'kl_loss':
        def f1(a):
            # epsilon guards against log(0) on zero-cost entries
            return a * nx.log(a + 1e-15) - a

        def f2(b):
            return b

        def h1(a):
            return a

        def h2(b):
            return nx.log(b + 1e-15)
    else:
        raise ValueError(
            f"Unknown loss_fun='{loss_fun}'. Use 'square_loss' or 'kl_loss'.")

    # constC = f1(C1) p 1_nt^T + 1_ns q^T f2(C2)^T  (Eq. (6) in [12])
    constC1 = nx.dot(
        nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
        nx.ones((1, len(q)), type_as=q)
    )
    constC2 = nx.dot(
        nx.ones((len(p), 1), type_as=p),
        nx.dot(nx.reshape(q, (1, -1)), f2(C2).T)
    )
    constC = constC1 + constC2
    hC1 = h1(C1)
    hC2 = h2(C2)

    return constC, hC1, hC2

def tensor_product(constC, hC1, hC2, T, nx=None):
    r"""Return the tensor for Gromov-Wasserstein fast computation

    The tensor is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-tensor-product>`

    Parameters
    ----------
    constC : array-like, shape (ns, nt)
        Constant :math:`\mathbf{C}` matrix in Eq. (6)
    hC1 : array-like, shape (ns, ns)
        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
    hC2 : array-like, shape (nt, nt)
        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
    T : array-like, shape (ns, nt)
        Current value of transport matrix :math:`\mathbf{T}`
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    tens : array-like, shape (ns, nt)
        :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` tensor-matrix multiplication result


    .. _references-tensor-product:
    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    """
    if nx is None:
        constC, hC1, hC2, T = list_to_array(constC, hC1, hC2, T)
        nx = get_backend(constC, hC1, hC2, T)

    # Eq. (6) in [12]: L(C1, C2) otimes T = constC - h1(C1) T h2(C2)^T
    A = -nx.dot(nx.dot(hC1, T), hC2.T)
    tens = constC + A
    return tens

def gwloss(constC, hC1, hC2, T, nx=None):
    r"""Return the Loss for Gromov-Wasserstein

    The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] <references-gwloss>`

    Parameters
    ----------
    constC : array-like, shape (ns, nt)
        Constant :math:`\mathbf{C}` matrix in Eq. (6)
    hC1 : array-like, shape (ns, ns)
        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
    hC2 : array-like, shape (nt, nt)
        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
    T : array-like, shape (ns, nt)
        Current value of transport matrix :math:`\mathbf{T}`
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    loss : float
        Gromov-Wasserstein loss


    .. _references-gwloss:
    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    """
    tens = tensor_product(constC, hC1, hC2, T, nx)
    if nx is None:
        tens, T = list_to_array(tens, T)
        nx = get_backend(tens, T)

    # Frobenius inner product <L(C1, C2) otimes T, T>
    return nx.sum(tens * T)

def gwggrad(constC, hC1, hC2, T, nx=None):
    r"""Return the gradient for Gromov-Wasserstein

    The gradient is computed as described in Proposition 2 in :ref:`[12] <references-gwggrad>`

    Parameters
    ----------
    constC : array-like, shape (ns, nt)
        Constant :math:`\mathbf{C}` matrix in Eq. (6)
    hC1 : array-like, shape (ns, ns)
        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
    hC2 : array-like, shape (nt, nt)
        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
    T : array-like, shape (ns, nt)
        Current value of transport matrix :math:`\mathbf{T}`
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    grad : array-like, shape (ns, nt)
        Gromov-Wasserstein gradient


    .. _references-gwggrad:
    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    """
    return 2 * tensor_product(constC, hC1, hC2,
                              T, nx)  # [12] Prop. 2 misses a 2 factor

def update_square_loss(p, lambdas, T, Cs, nx=None):
    r"""
    Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the S
    :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
    barycenter problem in :ref:`[12]`:

    .. math::

        \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)

    Where :

    - :math:`\mathbf{C}_s`: metric cost matrix
    - :math:`\mathbf{p}_s`: distribution

    Parameters
    ----------
    p : array-like, shape (N,)
        Masses in the targeted barycenter.
    lambdas : list of float
        List of the S spaces' weights.
    T : list of S array-like of shape (N, ns)
        The S :math:`\mathbf{T}_s` couplings calculated at each iteration.
    Cs : list of S array-like, shape(ns,ns)
        Metric cost matrices.
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    ----------
    C : array-like, shape (nt, nt)
        Updated :math:`\mathbf{C}` matrix.

    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    """
    if nx is None:
        nx = get_backend(p, *T, *Cs)

    # Correct order mistake in Equation 14 in [12]
    tmpsum = sum([
        lambdas[s] * nx.dot(
            nx.dot(T[s], Cs[s]),
            T[s].T
        ) for s in range(len(T))
    ])
    ppt = nx.outer(p, p)

    return tmpsum / ppt

def update_kl_loss(p, lambdas, T, Cs, nx=None):
    r"""
    Updates :math:`\mathbf{C}` according to the KL Loss kernel with the S
    :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
    barycenter problem in :ref:`[12]`:

    .. math::

        \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)

    Where :

    - :math:`\mathbf{C}_s`: metric cost matrix
    - :math:`\mathbf{p}_s`: distribution

    Parameters
    ----------
    p : array-like, shape (N,)
        Weights in the targeted barycenter.
    lambdas : list of float
        List of the S spaces' weights
    T : list of S array-like of shape (N, ns)
        The S :math:`\mathbf{T}_s` couplings calculated at each iteration.
    Cs : list of S array-like, shape(ns,ns)
        Metric cost matrices.
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    ----------
    C : array-like, shape (ns, ns)
        updated :math:`\mathbf{C}` matrix

    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    """
    if nx is None:
        nx = get_backend(p, *T, *Cs)

    # Correct order mistake in Equation 15 in [12]
    # clip Cs away from 0 before the log to stay finite
    tmpsum = sum([
        lambdas[s] * nx.dot(
            nx.dot(T[s], nx.log(nx.maximum(Cs[s], 1e-15))),
            T[s].T
        ) for s in range(len(T))
    ])
    ppt = nx.outer(p, p)

    return nx.exp(tmpsum / ppt)

def update_feature_matrix(lambdas, Ys, Ts, p, nx=None):
    r"""Updates the feature with respect to the S :math:`\mathbf{T}_s` couplings.

    See "Solving the barycenter problem with Block Coordinate Descent (BCD)"
    in :ref:`[24] <references-update-feature-matrix>` calculated at each iteration

    Parameters
    ----------
    lambdas : list of float
        List of the S spaces' weights
    Ys : list of S array-like, shape (d,ns)
        The features.
    Ts : list of S array-like, shape (N, ns)
        The S :math:`\mathbf{T}_s` couplings calculated at each iteration
    p : array-like, shape (N,)
        masses in the targeted barycenter
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    X : array-like, shape (d, N)
        Updated feature matrix of the barycenter


    .. _references-update-feature-matrix:
    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    if nx is None:
        nx = get_backend(*Ys, *Ts, p)

    # each column of Ys[s] @ Ts[s].T is rescaled by the inverse barycenter mass
    p = 1. / p
    tmpsum = sum([
        lambdas[s] * nx.dot(Ys[s], Ts[s].T) * p[None, :]
        for s in range(len(Ts))
    ])
    return tmpsum

def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None):
    r"""Return loss matrices and tensors for semi-relaxed Gromov-Wasserstein fast computation

    Returns the value of :math:`\mathcal{L}(\mathbf{C_1}, \mathbf{C_2}) \otimes \mathbf{T}` with the
    selected loss function as the loss function of semi-relaxed Gromov-Wasserstein discrepancy.

    The matrices are computed as described in Proposition 1 in :ref:`[12] <references-init-matrix>`
    and adapted to the semi-relaxed problem where the second marginal is not a constant anymore.

    Where :

    - :math:`\mathbf{C_1}`: Metric cost matrix in the source space
    - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
    - :math:`\mathbf{T}`: A coupling between those two spaces

    The square-loss function :math:`L(a, b) = |a - b|^2` is read as :

    .. math::

        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)

        \mathrm{with} \ f_1(a) &= a^2

        f_2(b) &= b^2

        h_1(a) &= a

        h_2(b) &= 2b

    The kl-loss function :math:`L(a, b) = a \log\left(\frac{a}{b}\right) - a + b` is read as :

    .. math::

        L(a, b) = f_1(a) + f_2(b) - h_1(a) h_2(b)

        \mathrm{with} \ f_1(a) &= a \log(a) - a

        f_2(b) &= b

        h_1(a) &= a

        h_2(b) &= \log(b)

    Parameters
    ----------
    C1 : array-like, shape (ns, ns)
        Metric cost matrix in the source space
    C2 : array-like, shape (nt, nt)
        Metric cost matrix in the target space
    p : array-like, shape (ns,)
        Probability distribution in the source space
    loss_fun : str, optional
        Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss')
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.

    Returns
    -------
    constC : array-like, shape (ns, nt)
        Constant :math:`\mathbf{C}` matrix in Eq. (6) adapted to srGW
    hC1 : array-like, shape (ns, ns)
        :math:`\mathbf{h1}(\mathbf{C1})` matrix in Eq. (6)
    hC2 : array-like, shape (nt, nt)
        :math:`\mathbf{h2}(\mathbf{C2})` matrix in Eq. (6)
    fC2t : array-like, shape (nt, nt)
        :math:`\mathbf{f2}(\mathbf{C2})^\top` matrix in Eq. (6)


    .. _references-init-matrix:
    References
    ----------
    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.

    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
        "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
        International Conference on Learning Representations (ICLR), 2022.
    """
    if nx is None:
        C1, C2, p = list_to_array(C1, C2, p)
        nx = get_backend(C1, C2, p)

    if loss_fun == 'square_loss':
        def f1(a):
            return a**2

        def f2(b):
            return b**2

        def h1(a):
            return a

        def h2(b):
            return 2 * b
    elif loss_fun == 'kl_loss':
        def f1(a):
            # epsilon guards against log(0) on zero-cost entries
            return a * nx.log(a + 1e-15) - a

        def f2(b):
            return b

        def h1(a):
            return a

        def h2(b):
            return nx.log(b + 1e-15)
    else:
        raise ValueError(
            f"Unknown loss_fun='{loss_fun}'. Use 'square_loss' or 'kl_loss'.")

    # Only the first marginal term is constant in srGW; the term depending on
    # the (free) second marginal is returned as fC2t for later use.
    constC = nx.dot(nx.dot(f1(C1), nx.reshape(p, (-1, 1))),
                    nx.ones((1, C2.shape[0]), type_as=p))

    hC1 = h1(C1)
    hC2 = h2(C2)
    fC2t = f2(C2).T
    return constC, hC1, hC2, fC2t