Source code for ot.gaussian

# -*- coding: utf-8 -*-
"""
Optimal transport for Gaussian distributions
"""

# Author: Theo Gnassounou <theo.gnassounou@inria.fr>
#         Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

import warnings

from .backend import get_backend
from .utils import dots, is_all_finite, list_to_array


def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
    r"""Return OT linear operator between samples.

    The function estimates the optimal linear operator that aligns the two
    empirical distributions. This is equivalent to estimating the closed
    form mapping between two Gaussian distributions
    :math:`\mathcal{N}(\mu_s,\Sigma_s)` and :math:`\mathcal{N}(\mu_t,\Sigma_t)`,
    as proposed in :ref:`[1] <references-OT-mapping-linear>` and discussed in
    remark 2.29 of :ref:`[2] <references-OT-mapping-linear>`.

    The linear operator from source to target :math:`M`

    .. math::
        M(\mathbf{x}) = \mathbf{A} \mathbf{x} + \mathbf{b}

    where:

    .. math::
        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}
        \right)^{1/2} \Sigma_s^{-1/2}

        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s

    Parameters
    ----------
    ms : array-like (d,)
        mean of the source distribution
    mt : array-like (d,)
        mean of the target distribution
    Cs : array-like (d,d)
        covariance of the source distribution
    Ct : array-like (d,d)
        covariance of the target distribution
    log : bool, optional
        record log if True

    Returns
    -------
    A : (d, d) array-like
        Linear operator
    b : (1, d) array-like
        bias
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-OT-mapping-linear:
    References
    ----------
    .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
        distributions", Journal of Optimization Theory and Applications
        Vol 43, 1984.

    .. [2] Peyré, G., & Cuturi, M. (2018). "Computational Optimal Transport".
    """
    ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
    nx = get_backend(ms, mt, Cs, Ct)

    Cs12 = nx.sqrtm(Cs)
    Cs12inv = nx.inv(Cs12)

    M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))

    A = dots(Cs12inv, M0, Cs12inv)

    b = mt - nx.dot(ms, A)

    if log:
        log = {}
        log["Cs12"] = Cs12
        log["Cs12inv"] = Cs12inv
        return A, b, log
    else:
        return A, b
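
A minimal usage sketch, assuming NumPy inputs (``get_backend`` dispatches on
them). Note that the operator applies in the row-vector convention used by the
code above::

    import numpy as np
    from ot.gaussian import bures_wasserstein_mapping

    ms, mt = np.zeros(2), np.array([1.0, 2.0])
    Cs = np.eye(2)
    Ct = np.array([[2.0, 0.5], [0.5, 1.0]])
    A, b = bures_wasserstein_mapping(ms, mt, Cs, Ct)

    x = np.random.randn(5, 2)
    x_mapped = x.dot(A) + b  # x -> x A + b

With ``Cs`` the identity, ``A`` reduces to the matrix square root of ``Ct``.
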
def empirical_bures_wasserstein_mapping(
    xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False
):
    r"""Return OT linear operator between samples.

    The function estimates the optimal linear operator that aligns the two
    empirical distributions. This is equivalent to estimating the closed
    form mapping between two Gaussian distributions
    :math:`\mathcal{N}(\mu_s,\Sigma_s)` and :math:`\mathcal{N}(\mu_t,\Sigma_t)`,
    as proposed in :ref:`[1] <references-OT-mapping-linear>` and discussed in
    remark 2.29 of :ref:`[2] <references-OT-mapping-linear>`.

    The linear operator from source to target :math:`M`

    .. math::
        M(\mathbf{x}) = \mathbf{A} \mathbf{x} + \mathbf{b}

    where:

    .. math::
        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}
        \right)^{1/2} \Sigma_s^{-1/2}

        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s

    Parameters
    ----------
    xs : array-like (ns,d)
        samples in the source domain
    xt : array-like (nt,d)
        samples in the target domain
    reg : float, optional
        regularization added to the diagonals of covariances (>0)
    ws : array-like (ns,1), optional
        weights for the source samples
    wt : array-like (nt,1), optional
        weights for the target samples
    bias : boolean, optional
        estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
    log : bool, optional
        record log if True

    Returns
    -------
    A : (d, d) array-like
        Linear operator
    b : (1, d) array-like
        bias
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-OT-mapping-linear:
    References
    ----------
    .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
        distributions", Journal of Optimization Theory and Applications
        Vol 43, 1984.

    .. [2] Peyré, G., & Cuturi, M. (2018). "Computational Optimal Transport".
    """
    xs, xt = list_to_array(xs, xt)
    nx = get_backend(xs, xt)
    is_input_finite = is_all_finite(xs, xt)

    d = xs.shape[1]

    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    if bias:
        mxs = nx.dot(ws.T, xs) / nx.sum(ws)
        mxt = nx.dot(wt.T, xt) / nx.sum(wt)

        xs = xs - mxs
        xt = xt - mxt
    else:
        mxs = nx.zeros((1, d), type_as=xs)
        mxt = nx.zeros((1, d), type_as=xs)

    Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
    Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)

    if log:
        A, b, log = bures_wasserstein_mapping(mxs, mxt, Cs, Ct, log=log)
    else:
        A, b = bures_wasserstein_mapping(mxs, mxt, Cs, Ct)

    if is_input_finite and not is_all_finite(A, b):
        warnings.warn(
            "Numerical errors were encountered in ot.gaussian.empirical_bures_wasserstein_mapping. "
            "Consider increasing the regularization parameter `reg`."
        )

    if log:
        log["Cs"] = Cs
        log["Ct"] = Ct
        return A, b, log
    else:
        return A, b
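
A sketch on synthetic samples (NumPy assumed), aligning a source cloud onto
the target statistics::

    import numpy as np
    from ot.gaussian import empirical_bures_wasserstein_mapping

    rng = np.random.RandomState(42)
    xs = rng.randn(100, 2)             # source: standard normal samples
    xt = rng.randn(80, 2) * 2.0 + 5.0  # target: scaled and shifted samples
    A, b = empirical_bures_wasserstein_mapping(xs, xt)

    # push the source samples onto the target mean/covariance
    xs_aligned = xs.dot(A) + b
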
def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
    r"""Return the Bures-Wasserstein distance between Gaussian distributions.

    The function computes the Bures-Wasserstein distance between two Gaussian
    distributions, source :math:`\mu_s` and target :math:`\mu_t`, given their
    means and covariances, as discussed in remark 2.31 of
    :ref:`[1] <references-bures-wasserstein-distance>`.

    The Bures-Wasserstein distance between source and target distribution
    :math:`\mathcal{W}`

    .. math::
        \mathcal{W}(\mu_s, \mu_t)_2^2 = \left\lVert \mathbf{m}_s - \mathbf{m}_t
        \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}

    where:

    .. math::
        \mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t
        - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)

    Parameters
    ----------
    ms : array-like (d,)
        mean of the source distribution
    mt : array-like (d,)
        mean of the target distribution
    Cs : array-like (d,d)
        covariance of the source distribution
    Ct : array-like (d,d)
        covariance of the target distribution
    log : bool, optional
        record log if True

    Returns
    -------
    W : float
        Bures-Wasserstein distance
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-bures-wasserstein-distance:
    References
    ----------
    .. [1] Peyré, G., & Cuturi, M. (2018). "Computational Optimal Transport".
    """
    ms, mt, Cs, Ct = list_to_array(ms, mt, Cs, Ct)
    nx = get_backend(ms, mt, Cs, Ct)

    Cs12 = nx.sqrtm(Cs)

    B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12)))
    W = nx.sqrt(nx.maximum(nx.norm(ms - mt) ** 2 + B, 0))

    if log:
        log = {}
        log["Cs12"] = Cs12
        return W, log
    else:
        return W
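
A small worked check, assuming NumPy inputs::

    import numpy as np
    from ot.gaussian import bures_wasserstein_distance

    ms, mt = np.zeros(2), np.ones(2)
    Cs, Ct = np.eye(2), 4 * np.eye(2)
    W = bures_wasserstein_distance(ms, mt, Cs, Ct)
    # ||ms - mt||^2 = 2 and B^2 = Tr(I + 4I - 2 * 2I) = 2, so W = sqrt(4) = 2
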
def empirical_bures_wasserstein_distance(
    xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False
):
    r"""Return the Bures-Wasserstein distance between empirical distributions.

    The function estimates the Bures-Wasserstein distance between two
    empirical distributions, source :math:`\mu_s` and target :math:`\mu_t`,
    whose means and covariances are estimated from the provided samples, as
    discussed in remark 2.31 of
    :ref:`[1] <references-bures-wasserstein-distance>`.

    The Bures-Wasserstein distance between source and target distribution
    :math:`\mathcal{W}`

    .. math::
        \mathcal{W}(\mu_s, \mu_t)_2^2 = \left\lVert \mathbf{m}_s - \mathbf{m}_t
        \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}

    where:

    .. math::
        \mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t
        - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)

    Parameters
    ----------
    xs : array-like (ns,d)
        samples in the source domain
    xt : array-like (nt,d)
        samples in the target domain
    reg : float, optional
        regularization added to the diagonals of covariances (>0)
    ws : array-like (ns,1), optional
        weights for the source samples
    wt : array-like (nt,1), optional
        weights for the target samples
    bias : boolean, optional
        estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
    log : bool, optional
        record log if True

    Returns
    -------
    W : float
        Bures-Wasserstein distance
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-bures-wasserstein-distance:
    References
    ----------
    .. [1] Peyré, G., & Cuturi, M. (2018). "Computational Optimal Transport".
    """
    xs, xt = list_to_array(xs, xt)
    nx = get_backend(xs, xt)

    d = xs.shape[1]

    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    if bias:
        mxs = nx.dot(ws.T, xs) / nx.sum(ws)
        mxt = nx.dot(wt.T, xt) / nx.sum(wt)

        xs = xs - mxs
        xt = xt - mxt
    else:
        mxs = nx.zeros((1, d), type_as=xs)
        mxt = nx.zeros((1, d), type_as=xs)

    Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
    Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)

    if log:
        W, log = bures_wasserstein_distance(mxs, mxt, Cs, Ct, log=log)
        log["Cs"] = Cs
        log["Ct"] = Ct
        return W, log
    else:
        W = bures_wasserstein_distance(mxs, mxt, Cs, Ct)
        return W
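
The empirical version only needs samples; a sketch with NumPy::

    import numpy as np
    from ot.gaussian import empirical_bures_wasserstein_distance

    rng = np.random.RandomState(3)
    xs = rng.randn(100, 2)
    xt = rng.randn(120, 2) + 4.0
    W = empirical_bures_wasserstein_distance(xs, xt)
    # both covariances are close to the identity, so for large samples W is
    # close to the mean shift ||(4, 4)|| = 4 * sqrt(2)
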
def bures_wasserstein_barycenter(
    m, C, weights=None, num_iter=1000, eps=1e-7, log=False
):
    r"""Return the Bures-Wasserstein barycenter of Gaussian distributions.

    The function computes the barycenter of multiple Gaussian distributions
    :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n` by solving a
    fixed-point algorithm :ref:`[1] <references-OT-mapping-linear-barycenter>`.

    The barycenter is itself a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^n w_i \mu_i

    and the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}

    Parameters
    ----------
    m : array-like (k,d)
        mean of k distributions
    C : array-like (k,d,d)
        covariance of k distributions
    weights : array-like (k), optional
        weights for each distribution
    num_iter : int, optional
        number of iterations for the fixed point algorithm
    eps : float, optional
        tolerance for the fixed point algorithm
    log : bool, optional
        record log if True

    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-OT-mapping-linear-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    nx = get_backend(*C, *m)

    if weights is None:
        weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]

    # Compute the mean barycenter
    mb = nx.sum(m * weights[:, None], axis=0)

    # Init the covariance barycenter
    Cb = nx.mean(C * weights[:, None, None], axis=0)

    for it in range(num_iter):
        # fixed point update
        Cb12 = nx.sqrtm(Cb)

        Cnew = Cb12 @ C @ Cb12
        C_ = []
        for i in range(len(C)):
            C_.append(nx.sqrtm(Cnew[i]))
        Cnew = nx.stack(C_, axis=0)
        Cnew *= weights[:, None, None]
        Cnew = nx.sum(Cnew, axis=0)

        # check convergence
        diff = nx.norm(Cb - Cnew)
        if diff <= eps:
            break
        Cb = Cnew
    else:
        warnings.warn("Fixed point algorithm did not converge.")

    if log:
        log = {}
        log["num_iter"] = it
        log["final_diff"] = diff
        return mb, Cb, log
    else:
        return mb, Cb
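
A sketch with two commuting covariances, where the fixed point has a closed
form (NumPy assumed)::

    import numpy as np
    from ot.gaussian import bures_wasserstein_barycenter

    m = np.stack([np.zeros(2), np.ones(2)])   # (k, d) means
    C = np.stack([np.eye(2), 4 * np.eye(2)])  # (k, d, d) covariances
    mb, Cb = bures_wasserstein_barycenter(m, C)
    # equal weights give mb = [0.5, 0.5]; for commuting matrices the fixed
    # point is ((sqrt(Cs) + sqrt(Ct)) / 2)^2 = 2.25 * I, which Cb approaches
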
def empirical_bures_wasserstein_barycenter(
    X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7, w=None, bias=True, log=False
):
    r"""Return the Bures-Wasserstein barycenter of empirical distributions.

    The function estimates the barycenter of empirical distributions, whose
    means and covariances are computed from the provided samples. This amounts
    to solving the fixed-point algorithm for multiple Gaussian distributions
    :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n`
    :ref:`[1] <references-OT-mapping-linear-barycenter>`.

    The barycenter is itself a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^n w_i \mu_i

    and the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}

    Parameters
    ----------
    X : list of array-like (n,d)
        samples in each distribution
    reg : float, optional
        regularization added to the diagonals of covariances (>0)
    weights : array-like (n,), optional
        weights for each distribution
    num_iter : int, optional
        number of iterations for the fixed point algorithm
    eps : float, optional
        tolerance for the fixed point algorithm
    w : list of array-like (n,), optional
        weights for each sample in each distribution
    bias : boolean, optional
        estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
    log : bool, optional
        record log if True

    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-OT-mapping-linear-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    X = list_to_array(*X)
    nx = get_backend(*X)

    k = len(X)
    d = [X[i].shape[1] for i in range(k)]

    if w is None:
        w = [
            nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0]
            for i in range(k)
        ]

    if bias:
        m = [nx.dot(w[i].T, X[i]) / nx.sum(w[i]) for i in range(k)]
        X = [X[i] - m[i] for i in range(k)]
    else:
        m = [nx.zeros((1, d[i]), type_as=X[i]) for i in range(k)]

    C = [
        nx.dot((X[i] * w[i]).T, X[i]) / nx.sum(w[i])
        + reg * nx.eye(d[i], type_as=X[i])
        for i in range(k)
    ]
    m = nx.stack(m, axis=0)
    C = nx.stack(C, axis=0)

    if log:
        mb, Cb, log = bures_wasserstein_barycenter(
            m, C, weights=weights, num_iter=num_iter, eps=eps, log=log
        )
        return mb, Cb, log
    else:
        mb, Cb = bures_wasserstein_barycenter(
            m, C, weights=weights, num_iter=num_iter, eps=eps, log=log
        )
        return mb, Cb
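
A sketch feeding raw samples from two distributions (NumPy assumed; all
sample sets must share the same dimension so the barycenter is well defined)::

    import numpy as np
    from ot.gaussian import empirical_bures_wasserstein_barycenter

    rng = np.random.RandomState(0)
    X = [rng.randn(50, 2), rng.randn(60, 2) + 3.0]  # two sample sets in R^2
    mb, Cb = empirical_bures_wasserstein_barycenter(X)
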
def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False):
    r"""Return the Gaussian Gromov-Wasserstein value from [57].

    This function returns the closed-form value of the Gaussian
    Gromov-Wasserstein distance between two Gaussian distributions
    :math:`\mathcal{N}(\mu_s,\Sigma_s)` and :math:`\mathcal{N}(\mu_t,\Sigma_t)`
    when the OT plan is assumed to be also Gaussian. See Theorem 4.1 in [57]
    for more details.

    Parameters
    ----------
    Cov_s : array-like (ds,ds)
        covariance of the source distribution
    Cov_t : array-like (dt,dt)
        covariance of the target distribution
    log : bool, optional
        record log if True

    Returns
    -------
    G : float
        Gaussian Gromov-Wasserstein distance
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-gaussian_gromov_wasserstein_distance:
    References
    ----------
    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
        distances between Gaussian distributions. Journal of Applied
        Probability, 59(4), 1178-1198.
    """
    nx = get_backend(Cov_s, Cov_t)

    # ensure that Cov_s is the larger covariance matrix, that is m >= n
    if Cov_s.shape[0] < Cov_t.shape[0]:
        Cov_s, Cov_t = Cov_t, Cov_s

    n = Cov_t.shape[0]

    # compute and sort eigenvalues in decreasing order
    d_s = nx.flip(nx.sort(nx.eigh(Cov_s)[0]))
    d_t = nx.flip(nx.sort(nx.eigh(Cov_t)[0]))

    # compute the Gaussian Gromov-Wasserstein distance
    res = (
        4 * (nx.sum(d_s) - nx.sum(d_t)) ** 2
        + 8 * nx.sum((d_s[:n] - d_t) ** 2)
        + 8 * nx.sum((d_s[n:]) ** 2)
    )

    if log:
        log = {}
        log["d_s"] = d_s
        log["d_t"] = d_t
        return nx.sqrt(res), log
    else:
        return nx.sqrt(res)
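
A worked diagonal example, where the eigenvalues can be read off directly::

    import numpy as np
    from ot.gaussian import gaussian_gromov_wasserstein_distance

    Cov_s = np.diag([3.0, 1.0, 0.5])  # source in R^3
    Cov_t = np.diag([2.0, 1.0])       # target in R^2
    G = gaussian_gromov_wasserstein_distance(Cov_s, Cov_t)
    # res = 4 * (4.5 - 3)^2 + 8 * ((3-2)^2 + (1-1)^2) + 8 * 0.5^2 = 19,
    # so G = sqrt(19)
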
def empirical_gaussian_gromov_wasserstein_distance(xs, xt, ws=None, wt=None, log=False):
    r"""Return Gaussian Gromov-Wasserstein distance between samples.

    The function estimates the Gaussian Gromov-Wasserstein distance between
    two Gaussian distributions, source :math:`\mu_s` and target :math:`\mu_t`,
    whose parameters are estimated from the provided samples
    :math:`\mathcal{X}_s` and :math:`\mathcal{X}_t`. See Theorem 4.1 in [57]
    for more details.

    Parameters
    ----------
    xs : array-like (ns,d)
        samples in the source domain
    xt : array-like (nt,d)
        samples in the target domain
    ws : array-like (ns,1), optional
        weights for the source samples
    wt : array-like (nt,1), optional
        weights for the target samples
    log : bool, optional
        record log if True

    Returns
    -------
    G : float
        Gaussian Gromov-Wasserstein distance
    log : dict
        log dictionary, returned only if log==True in parameters


    .. _references-gaussian_gromov_wasserstein:
    References
    ----------
    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
        distances between Gaussian distributions. Journal of Applied
        Probability, 59(4), 1178-1198.
    """
    xs, xt = list_to_array(xs, xt)
    nx = get_backend(xs, xt)

    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    mxs = nx.dot(ws.T, xs) / nx.sum(ws)
    mxt = nx.dot(wt.T, xt) / nx.sum(wt)

    xs = xs - mxs
    xt = xt - mxt

    Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws)
    Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt)

    if log:
        G, log = gaussian_gromov_wasserstein_distance(Cs, Ct, log=log)
        log["Cov_s"] = Cs
        log["Cov_t"] = Ct
        return G, log
    else:
        G = gaussian_gromov_wasserstein_distance(Cs, Ct)
        return G
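
On samples, means are removed internally, so only the shapes of the clouds
matter; a sketch across dimensions::

    import numpy as np
    from ot.gaussian import empirical_gaussian_gromov_wasserstein_distance

    rng = np.random.RandomState(5)
    xs = rng.randn(150, 3)  # source samples in R^3
    xt = rng.randn(150, 2)  # target samples in R^2
    G = empirical_gaussian_gromov_wasserstein_distance(xs, xt)
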
def gaussian_gromov_wasserstein_mapping(
    mu_s, mu_t, Cov_s, Cov_t, sign_eigs=None, log=False
):
    r"""Return the Gaussian Gromov-Wasserstein mapping from [57].

    This function returns the closed-form Gaussian Gromov-Wasserstein mapping
    between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` and
    :math:`\mathcal{N}(\mu_t,\Sigma_t)` when the OT plan is assumed to be also
    Gaussian. See Theorem 4.1 in [57] for more details.

    Parameters
    ----------
    mu_s : array-like (ds,)
        mean of the source distribution
    mu_t : array-like (dt,)
        mean of the target distribution
    Cov_s : array-like (ds,ds)
        covariance of the source distribution
    Cov_t : array-like (dt,dt)
        covariance of the target distribution
    sign_eigs : array-like (min(ds,dt),), optional
        sign of the eigenvalues of the mapping matrix, by default all signs
        will be positive
    log : bool, optional
        record log if True

    Returns
    -------
    A : (ds, dt) array-like
        Linear operator
    b : (1, dt) array-like
        bias


    .. _references-gaussian_gromov_wasserstein_mapping:
    References
    ----------
    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
        distances between Gaussian distributions. Journal of Applied
        Probability, 59(4), 1178-1198.
    """
    nx = get_backend(mu_s, mu_t, Cov_s, Cov_t)

    n = Cov_t.shape[0]
    m = Cov_s.shape[0]

    # compute and sort eigenvalues/eigenvectors in decreasing order
    d_s, U_s = nx.eigh(Cov_s)
    id_s = nx.flip(nx.argsort(d_s))
    d_s, U_s = d_s[id_s], U_s[:, id_s]

    d_t, U_t = nx.eigh(Cov_t)
    id_t = nx.flip(nx.argsort(d_t))
    d_t, U_t = d_t[id_t], U_t[:, id_t]

    if sign_eigs is None:
        sign_eigs = nx.ones(min(m, n), type_as=mu_s)

    if m >= n:
        A = nx.concatenate(
            (
                nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])),
                nx.zeros((n, m - n), type_as=mu_s),
            ),
            axis=1,
        ).T
    else:
        A = nx.concatenate(
            (
                nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)),
                nx.zeros((n - m, m), type_as=mu_s),
            ),
            axis=0,
        ).T
    A = nx.dot(nx.dot(U_s, A), U_t.T)

    # compute the bias term
    b = mu_t - nx.dot(mu_s, A)

    if log:
        log = {}
        log["d_s"] = d_s
        log["d_t"] = d_t
        log["U_s"] = U_s
        log["U_t"] = U_t
        return A, b, log
    else:
        return A, b
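
A sketch mapping across dimensions (here ds = 3 to dt = 2), NumPy assumed::

    import numpy as np
    from ot.gaussian import gaussian_gromov_wasserstein_mapping

    mu_s, Cov_s = np.zeros(3), np.diag([3.0, 2.0, 1.0])
    mu_t, Cov_t = np.zeros(2), np.diag([4.0, 1.0])
    A, b = gaussian_gromov_wasserstein_mapping(mu_s, mu_t, Cov_s, Cov_t)

    # A is (ds, dt): it maps R^3 samples onto R^2 in the row-vector convention
    x_mapped = np.random.randn(5, 3).dot(A) + b
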
def empirical_gaussian_gromov_wasserstein_mapping(
    xs, xt, ws=None, wt=None, sign_eigs=None, log=False
):
    r"""Return Gaussian Gromov-Wasserstein mapping between samples.

    The function estimates the Gaussian Gromov-Wasserstein mapping between
    two Gaussian distributions, source :math:`\mu_s` and target :math:`\mu_t`,
    whose parameters are estimated from the provided samples
    :math:`\mathcal{X}_s` and :math:`\mathcal{X}_t`. See Theorem 4.1 in [57]
    for more details.

    Parameters
    ----------
    xs : array-like (ns,ds)
        samples in the source domain
    xt : array-like (nt,dt)
        samples in the target domain
    ws : array-like (ns,1), optional
        weights for the source samples
    wt : array-like (nt,1), optional
        weights for the target samples
    sign_eigs : array-like (min(ds,dt),) or string, optional
        sign of the eigenvalues of the mapping matrix, by default all signs
        will be positive. If 'skewness' is provided, the sign of each
        eigenvalue is selected as the product of the signs of the skewness of
        the projected data.
    log : bool, optional
        record log if True

    Returns
    -------
    A : (ds, dt) array-like
        Linear operator
    b : (1, dt) array-like
        bias


    .. _references-empirical_gaussian_gromov_wasserstein_mapping:
    References
    ----------
    .. [57] Delon, J., Desolneux, A., & Salmona, A. (2022). Gromov–Wasserstein
        distances between Gaussian distributions. Journal of Applied
        Probability, 59(4), 1178-1198.
    """
    xs, xt = list_to_array(xs, xt)
    nx = get_backend(xs, xt)

    m = xs.shape[1]
    n = xt.shape[1]

    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    # estimate mean and covariance
    mu_s = nx.dot(ws.T, xs) / nx.sum(ws)
    mu_t = nx.dot(wt.T, xt) / nx.sum(wt)

    xs = xs - mu_s
    xt = xt - mu_t

    Cov_s = nx.dot((xs * ws).T, xs) / nx.sum(ws)
    Cov_t = nx.dot((xt * wt).T, xt) / nx.sum(wt)

    # compute and sort eigenvalues/eigenvectors in decreasing order
    d_s, U_s = nx.eigh(Cov_s)
    id_s = nx.flip(nx.argsort(d_s))
    d_s, U_s = d_s[id_s], U_s[:, id_s]

    d_t, U_t = nx.eigh(Cov_t)
    id_t = nx.flip(nx.argsort(d_t))
    d_t, U_t = d_t[id_t], U_t[:, id_t]

    # select the sign of the eigenvalues
    if sign_eigs is None:
        sign_eigs = nx.ones(min(m, n), type_as=mu_s)
    elif sign_eigs == "skewness":
        size = min(m, n)
        skew_s = nx.sum((nx.dot(xs, U_s[:, :size])) ** 3 * ws, axis=0)
        skew_t = nx.sum((nx.dot(xt, U_t[:, :size])) ** 3 * wt, axis=0)
        sign_eigs = nx.sign(skew_t * skew_s)

    if m >= n:
        A = nx.concatenate(
            (
                nx.diag(sign_eigs * nx.sqrt(d_t) / nx.sqrt(d_s[:n])),
                nx.zeros((n, m - n), type_as=mu_s),
            ),
            axis=1,
        ).T
    else:
        A = nx.concatenate(
            (
                nx.diag(sign_eigs * nx.sqrt(d_t[:m]) / nx.sqrt(d_s)),
                nx.zeros((n - m, m), type_as=mu_s),
            ),
            axis=0,
        ).T
    A = nx.dot(nx.dot(U_s, A), U_t.T)

    # compute the bias term
    b = mu_t - nx.dot(mu_s, A)

    if log:
        log = {}
        log["d_s"] = d_s
        log["d_t"] = d_t
        log["U_s"] = U_s
        log["U_t"] = U_t
        log["Cov_s"] = Cov_s
        log["Cov_t"] = Cov_t
        return A, b, log
    else:
        return A, b
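
A sketch using the 'skewness' heuristic to pick the eigenvalue signs, on
deliberately skewed samples (NumPy assumed)::

    import numpy as np
    from ot.gaussian import empirical_gaussian_gromov_wasserstein_mapping

    rng = np.random.RandomState(1)
    xs = rng.lognormal(size=(200, 3))        # skewed source cloud in R^3
    xt = rng.lognormal(size=(200, 2)) * 2.0  # skewed target cloud in R^2
    A, b = empirical_gaussian_gromov_wasserstein_mapping(
        xs, xt, sign_eigs="skewness"
    )
    xs_mapped = xs.dot(A) + b
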