Source code for ot.bregman._barycenter

# -*- coding: utf-8 -*-
"""
Bregman projections solvers for entropic regularized wasserstein barycenters
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Nicolas Courty <ncourty@irisa.fr>
#         Hicham Janati <hicham.janati100@gmail.com>
#         Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
#
# License: MIT License

import warnings
import numpy as np

from ..utils import dist, list_to_array, unif
from ..backend import get_backend

from ._utils import geometricBar, geometricMean, projR, projC
from ._sinkhorn import sinkhorn


def barycenter(
    A,
    M,
    reg,
    weights=None,
    method="sinkhorn",
    numItermax=10000,
    stopThr=1e-4,
    verbose=False,
    log=False,
    warn=True,
    **kwargs,
):
    r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of
      matrix :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term
      and the cost matrix for OT

    This function is only a dispatcher: depending on `method` (compared
    case-insensitively) it forwards every argument to
    :py:func:`barycenter_sinkhorn`, :py:func:`barycenter_stabilized` or the
    private log-domain solver. The underlying algorithm is the Sinkhorn-Knopp
    matrix scaling algorithm as proposed in :ref:`[3] <references-barycenter>`.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex
        (barycentric coordinates)
    method : str (optional)
        method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
        or 'sinkhorn_log'
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    **kwargs
        Extra keyword arguments forwarded unchanged to the selected solver.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is not one of the supported solver names.


    .. _references-barycenter:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    """
    solver_name = method.lower()

    # Reject unsupported names up front so the dispatch below stays flat.
    if solver_name not in ("sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"):
        raise ValueError("Unknown method '%s'." % method)

    if solver_name == "sinkhorn":
        solver = barycenter_sinkhorn
    elif solver_name == "sinkhorn_stabilized":
        solver = barycenter_stabilized
    else:
        solver = _barycenter_sinkhorn_log

    return solver(
        A,
        M,
        reg,
        weights=weights,
        numItermax=numItermax,
        stopThr=stopThr,
        verbose=verbose,
        log=log,
        warn=warn,
        **kwargs,
    )
def barycenter_sinkhorn(
    A,
    M,
    reg,
    weights=None,
    numItermax=1000,
    stopThr=1e-4,
    verbose=False,
    log=False,
    warn=True,
):
    r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of
      matrix :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term
      and the cost matrix for OT

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in
    :ref:`[3]<references-barycenter-sinkhorn>`.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex
        (barycentric coordinates). Uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        list of recorded errors under ``"err"`` and the index of the last
        iteration under ``"niter"``.


    .. _references-barycenter-sinkhorn:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    """
    A, M = list_to_array(A, M)

    nx = get_backend(A, M)

    if weights is None:
        weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
    else:
        assert len(weights) == A.shape[1]

    if log:
        log = {"err": []}

    # Gibbs kernel of the cost matrix
    K = nx.exp(-M / reg)

    err = 1

    UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)

    u = (geometricMean(UKv) / UKv.T).T

    for ii in range(numItermax):
        UKv = u * nx.dot(K.T, A / nx.dot(K, u))
        u = (u.T * geometricBar(weights, UKv)).T / UKv

        # the error is only evaluated every 10th iteration (ii = 1, 11, ...)
        if ii % 10 == 1:
            # sum over rows of the spread of UKv across histograms
            err = nx.sum(nx.std(UKv, axis=1))

            # log and verbose print
            if log:
                log["err"].append(err)

            if err < stopThr:
                break
            if verbose:
                # Reports happen at ii = 1, 11, 21, ...; the header is
                # repeated every 200 iterations, starting with the first
                # report (ii % 200 == 0 can never hold in this branch).
                if ii % 200 == 1:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))
    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )
    if log:
        log["niter"] = ii
        return geometricBar(weights, UKv), log
    else:
        return geometricBar(weights, UKv)
def free_support_sinkhorn_barycenter(
    measures_locations,
    measures_weights,
    X_init,
    reg,
    b=None,
    weights=None,
    numItermax=100,
    numInnerItermax=1000,
    stopThr=1e-7,
    verbose=False,
    log=None,
    **kwargs,
):
    r"""
    Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:

    .. math::
        \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)

    where :

    - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
    - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
    - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
    - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter

    This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
    There are three differences with the following codes:

    - we do not optimize over the weights
    - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
      :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
      implementation of the fixed-point algorithm of
      :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
    - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
      transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).

    Parameters
    ----------
    measures_locations : list of N (k_i,d) array-like
        The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
        (:math:`k_i` can be different for each element of the list)
    measures_weights : list of N (k_i,) array-like
        Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
        representing the weights of each discrete input measure
    X_init : (k,d) array-like
        Initialization of the support locations (on `k` atoms) of the barycenter
    reg : float
        Regularization term >0
    b : (k,) array-like
        Initialization of the weights of the barycenter (non-negatives, sum to 1).
        Uniform weights are used when None.
    weights : (N,) array-like
        Initialization of the coefficients of the barycenter (non-negatives, sum to 1).
        Uniform coefficients are used when None.
    numItermax : int, optional
        Max number of iterations
    numInnerItermax : int, optional
        Max number of iterations when calculating the transport plans with Sinkhorn
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded to the inner `sinkhorn` call.

    Returns
    -------
    X : (k,d) array-like
        Support locations (on k atoms) of the barycenter
    log_dict : dict
        Returned only if `log` is truthy; its ``"displacement_square_norms"``
        entry lists the squared displacement recorded at every outer iteration.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT solver
    ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming


    .. _references-free-support-barycenter:
    References
    ----------
    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """
    nx = get_backend(*measures_locations, *measures_weights, X_init)

    iter_count = 0

    N = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]
    if b is None:
        b = nx.ones((k,), type_as=X_init) / k
    if weights is None:
        weights = nx.ones((N,), type_as=X_init) / N

    X = X_init

    log_dict = {}
    displacement_square_norms = []

    # start above the threshold so that the loop body runs at least once
    # (provided numItermax > 0)
    displacement_square_norm = stopThr + 1.0

    while displacement_square_norm > stopThr and iter_count < numItermax:
        T_sum = nx.zeros((k, d), type_as=X_init)

        for measure_locations_i, measure_weights_i, weight_i in zip(
            measures_locations, measures_weights, weights
        ):
            M_i = dist(X, measure_locations_i)
            T_i = sinkhorn(
                b,
                measure_weights_i,
                M_i,
                reg=reg,
                numItermax=numInnerItermax,
                **kwargs,
            )
            # accumulate the weighted image of the current support under
            # the plan T_i, rescaled row-wise by 1 / b
            T_sum = T_sum + weight_i * 1.0 / b[:, None] * nx.dot(
                T_i, measure_locations_i
            )

        displacement_square_norm = nx.sum((T_sum - X) ** 2)
        if log:
            displacement_square_norms.append(displacement_square_norm)

        X = T_sum

        if verbose:
            # Interpolate the values into the template; passing them as extra
            # positional arguments to print() would emit the raw "%d"/"%f"
            # placeholders followed by the unformatted values.
            print(
                "iteration %d, displacement_square_norm=%f\n"
                % (iter_count, displacement_square_norm)
            )

        iter_count += 1

    if log:
        log_dict["displacement_square_norms"] = displacement_square_norms
        return X, log_dict
    else:
        return X
def _barycenter_sinkhorn_log(
    A,
    M,
    reg,
    weights=None,
    numItermax=1000,
    stopThr=1e-4,
    verbose=False,
    log=False,
    warn=True,
):
    r"""Compute the entropic wasserstein barycenter in log-domain

    Log-domain counterpart of the Sinkhorn barycenter iterations: the kernel
    is kept as ``-M / reg`` and every product is carried out with
    ``logsumexp``.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions of size `dim` (one per column)
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,)
        Weights of each histogram; uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter (exponential of the last log-barycenter)
    log : dict
        Returned only if `log` is truthy; holds the recorded errors under
        ``"err"`` and the last iteration index under ``"niter"``.

    Raises
    ------
    NotImplementedError
        If the detected backend is named "jax" or "tf".
    """
    A, M = list_to_array(A, M)
    dim, n_hists = A.shape

    nx = get_backend(A, M)

    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and tf. Use numpy or torch arrays instead."
        )

    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert len(weights) == A.shape[1]

    if log:
        log = {"err": []}

    # from here on M holds the log of the Gibbs kernel
    M = -M / reg
    # small offset keeps the log finite where A has zero entries
    logA = nx.log(A + 1e-16)
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    err = 1
    for ii in range(numItermax):
        log_bar = nx.zeros(dim, type_as=A)
        for k in range(n_hists):
            f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
            log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
            log_bar = log_bar + weights[k] * log_KU[:, k]

        # the error is only evaluated every 10th iteration (ii = 1, 11, ...)
        if ii % 10 == 1:
            err = nx.exp(G + log_KU).std(axis=1).sum()

            # log and verbose print
            if log:
                log["err"].append(err)

            if err < stopThr:
                break
            if verbose:
                # Reports happen at ii = 1, 11, 21, ...; the header is
                # repeated every 200 iterations, starting with the first
                # report (ii % 200 == 0 can never hold in this branch).
                if ii % 200 == 1:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))

        G = log_bar[:, None] - log_KU

    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )
    if log:
        log["niter"] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)
def barycenter_stabilized(
    A,
    M,
    reg,
    tau=1e10,
    weights=None,
    numItermax=1000,
    stopThr=1e-4,
    verbose=False,
    log=False,
    warn=True,
):
    r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of
      matrix :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term
      and the cost matrix for OT

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in
    :ref:`[3] <references-barycenter-stabilized>`

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    tau : float
        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
        for log scaling
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex
        (barycentric coordinates). Uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        recorded errors (``"err"``), the last iteration index (``"niter"``)
        and the logarithms of the final scalings (``"logu"``, ``"logv"``).


    .. _references-barycenter-stabilized:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    """
    A, M = list_to_array(A, M)

    nx = get_backend(A, M)

    dim, n_hists = A.shape
    if weights is None:
        weights = nx.ones((n_hists,), type_as=M) / n_hists
    else:
        assert len(weights) == A.shape[1]

    if log:
        log = {"err": []}

    u = nx.ones((dim, n_hists), type_as=M) / dim
    v = nx.ones((dim, n_hists), type_as=M) / dim

    K = nx.exp(-M / reg)

    err = 1.0
    # alpha / beta accumulate the scalings absorbed into the kernel
    alpha = nx.zeros((dim,), type_as=M)
    beta = nx.zeros((dim,), type_as=M)
    q = nx.ones((dim,), type_as=M) / dim
    for ii in range(numItermax):
        # kept so that a numerically broken iteration can be rolled back
        qprev = q
        Kv = nx.dot(K, v)
        u = A / Kv
        Ktu = nx.dot(K.T, u)
        q = geometricBar(weights, Ktu)
        Q = q[:, None]
        v = Q / Ktu
        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            # a scaling exceeded tau: fold the scalings into alpha / beta,
            # rebuild the kernel from them and reset v to ones
            absorbing = True
            alpha += reg * nx.log(nx.max(u, 1))
            beta += reg * nx.log(nx.max(v, 1))
            K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
            v = nx.ones(tuple(v.shape), type_as=v)
        # recomputed with the updated v (and possibly K) for the error below
        Kv = nx.dot(K, v)
        if (
            nx.any(Ktu == 0.0)
            or nx.any(nx.isnan(u))
            or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u))
            or nx.any(nx.isinf(v))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %s" % ii)
            q = qprev
            break
        if (ii % 10 == 0 and not absorbing) or ii == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = nx.max(nx.abs(u * Kv - A))
            if log:
                log["err"].append(err)
            if err < stopThr:
                break
            if verbose:
                if ii % 50 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))

    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            # the sentences used to be concatenated without any separator
            warnings.warn(
                "Stabilized Sinkhorn did not converge. "
                "Try a larger entropy `reg` "
                "or a larger absorption threshold `tau`."
            )
    if log:
        log["niter"] = ii
        # small offset keeps the logs finite for zero scalings
        log["logu"] = nx.log(u + 1e-16)
        log["logv"] = nx.log(v + 1e-16)
        return q, log
    else:
        return q
def barycenter_debiased(
    A,
    M,
    reg,
    weights=None,
    method="sinkhorn",
    numItermax=10000,
    stopThr=1e-4,
    verbose=False,
    log=False,
    warn=True,
    **kwargs,
):
    r"""Compute the debiased Sinkhorn barycenter of distributions A

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence
      (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of
      matrix :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term
      and the cost matrix for OT

    This function is only a dispatcher: depending on `method` (compared
    case-insensitively) every argument is forwarded to the standard or to the
    log-domain private solver. The underlying algorithm is the debiased
    Sinkhorn algorithm as proposed in
    :ref:`[37] <references-barycenter-debiased>`

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex
        (barycentric coordinates)
    method : str (optional)
        method used for the solver either 'sinkhorn' or 'sinkhorn_log'
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    **kwargs
        Extra keyword arguments forwarded unchanged to the selected solver.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is neither 'sinkhorn' nor 'sinkhorn_log'.


    .. _references-barycenter-debiased:
    References
    ----------
    .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
        International Conference on Machine Learning, PMLR 119:4692-4701, 2020
    """
    solver_name = method.lower()

    # Reject unsupported names up front so the dispatch below stays flat.
    if solver_name not in ("sinkhorn", "sinkhorn_log"):
        raise ValueError("Unknown method '%s'." % method)

    if solver_name == "sinkhorn":
        solver = _barycenter_debiased
    else:
        solver = _barycenter_debiased_log

    return solver(
        A,
        M,
        reg,
        weights=weights,
        numItermax=numItermax,
        stopThr=stopThr,
        verbose=verbose,
        log=log,
        warn=warn,
        **kwargs,
    )
def _barycenter_debiased(
    A,
    M,
    reg,
    weights=None,
    numItermax=1000,
    stopThr=1e-4,
    verbose=False,
    log=False,
    warn=True,
):
    r"""Compute the debiased sinkhorn barycenter of distributions A.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions of size `dim` (one per column)
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,)
        Weights of each histogram; uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.

    Returns
    -------
    a : (dim,) array-like
        Debiased barycenter
    log : dict
        Returned only if `log` is truthy; holds the recorded errors under
        ``"err"`` and the last iteration index under ``"niter"``.
    """
    A, M = list_to_array(A, M)

    nx = get_backend(A, M)

    if weights is None:
        weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
    else:
        assert len(weights) == A.shape[1]

    if log:
        log = {"err": []}

    # Gibbs kernel of the cost matrix
    K = nx.exp(-M / reg)

    err = 1

    UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)

    u = (geometricMean(UKv) / UKv.T).T
    # c is the extra scaling vector updated alongside the barycenter
    c = nx.ones(A.shape[0], type_as=A)
    bar = nx.ones(A.shape[0], type_as=A)

    for ii in range(numItermax):
        # previous barycenter, used for the change-based error below
        bold = bar
        UKv = nx.dot(K, A / nx.dot(K, u))
        bar = c * geometricBar(weights, UKv)
        u = bar[:, None] / UKv
        c = (c * bar / nx.dot(K, c)) ** 0.5

        # the error is only evaluated every 10th iteration (ii = 9, 19, ...)
        if ii % 10 == 9:
            err = abs(bar - bold).max() / max(bar.max(), 1.0)

            # log and verbose print
            if log:
                log["err"].append(err)

            # debiased Sinkhorn does not converge monotonically
            # guarantee a few iterations are done before stopping
            if err < stopThr and ii > 20:
                break
            if verbose:
                # Reports happen at ii = 9, 19, 29, ...; the header is
                # repeated every 200 iterations, starting with the first
                # report (ii % 200 == 0 can never hold in this branch).
                if ii % 200 == 9:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))

    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )
    if log:
        log["niter"] = ii
        return bar, log
    else:
        return bar


def _barycenter_debiased_log(
    A,
    M,
    reg,
    weights=None,
    numItermax=1000,
    stopThr=1e-4,
    verbose=False,
    log=False,
    warn=True,
):
    r"""Compute the debiased sinkhorn barycenter in log domain.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions of size `dim` (one per column)
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,)
        Weights of each histogram; uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.

    Returns
    -------
    a : (dim,) array-like
        Debiased barycenter (exponential of the last log-barycenter)
    log : dict
        Returned only if `log` is truthy; holds the recorded errors under
        ``"err"`` and the last iteration index under ``"niter"``.

    Raises
    ------
    NotImplementedError
        If the detected backend is named "jax" or "tf".
    """
    A, M = list_to_array(A, M)
    dim, n_hists = A.shape

    nx = get_backend(A, M)
    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and TF. Use numpy or torch arrays instead."
        )

    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert len(weights) == A.shape[1]

    if log:
        log = {"err": []}

    # from here on M holds the log of the Gibbs kernel
    M = -M / reg
    # small offset keeps the log finite where A has zero entries
    logA = nx.log(A + 1e-16)
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    # c is the extra (log-domain) scaling vector refined after each iteration
    c = nx.zeros(dim, type_as=A)
    err = 1
    for ii in range(numItermax):
        log_bar = nx.zeros(dim, type_as=A)
        for k in range(n_hists):
            f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
            log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
            log_bar += weights[k] * log_KU[:, k]
        log_bar += c
        # the error is only evaluated every 10th iteration (ii = 1, 11, ...)
        if ii % 10 == 1:
            err = nx.exp(G + log_KU).std(axis=1).sum()

            # log and verbose print
            if log:
                log["err"].append(err)

            # at least a few iterations are done before stopping
            if err < stopThr and ii > 20:
                break
            if verbose:
                # Reports happen at ii = 1, 11, 21, ...; the header is
                # repeated every 200 iterations, starting with the first
                # report (ii % 200 == 0 can never hold in this branch).
                if ii % 200 == 1:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))

        G = log_bar[:, None] - log_KU
        # ten inner refinement steps of c per outer iteration
        for _ in range(10):
            c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0))

    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )
    if log:
        log["niter"] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)
def jcpot_barycenter(
    Xs,
    Ys,
    Xt,
    reg,
    metric="sqeuclidean",
    numItermax=100,
    stopThr=1e-6,
    verbose=False,
    log=False,
    warn=True,
    **kwargs,
):
    r"""Joint OT and proportion estimation for multi-source target shift as
    proposed in :ref:`[27] <references-jcpot-barycenter>`

    The function solves the following optimization problem:

    .. math::

        \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k
                    W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})

        s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}

    where :

    - :math:`\lambda_k` is the weight of `k`-th source domain
    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th
      source domain defined as in
      [p. 5, :ref:`27 <references-jcpot-barycenter>`]
    - :math:`\mathbf{h}` is a vector of estimated proportions in the target
      domain of size `C`
    - :math:`\mathbf{a}` is a uniform vector of weights in the target domain
      of size `n`
    - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as
      in [p. 5, :ref:`27 <references-jcpot-barycenter>`]

    In this implementation both :math:`\mathbf{D}_1^{(k)}` and
    :math:`\mathbf{D}_2^{(k)}` are built with one row per class and one
    column per sample of the `k`-th source domain.

    The problem consist in solving a Wasserstein barycenter problem to
    estimate the proportions :math:`\mathbf{h}` in the target domain.

    The algorithm used for solving the problem is the Iterative Bregman
    projections algorithm with two sets of marginal constraints related to
    the unknown vector :math:`\mathbf{h}` and uniform target distribution.

    Parameters
    ----------
    Xs : list of K array-like(nsk,d)
        features of all source domains' samples
    Ys : list of K array-like(nsk,)
        labels of all source domains' samples. When the smallest label of a
        domain is not 0 the labels are shifted on a copy; the arrays passed
        by the caller are left untouched.
    Xt : array-like (nt,d)
        samples in the target domain
    reg : float
        Regularization term > 0
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the change of the barycenter between two
        iterations (>0)
    verbose : bool, optional (default=False)
        Controls the verbosity of the optimization algorithm
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    **kwargs
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    h : (C,) array-like
        proportion estimation in the target domain
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        last iteration index (``"niter"``), the recorded errors (``"err"``),
        the cost matrices (``"M"``), the matrices ``"D1"`` and ``"D2"`` and
        the final coupling matrices (``"gamma"``).


    .. _references-jcpot-barycenter:
    References
    ----------

    .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
       "Optimal transport for multi-source domain adaptation under target
       shift", International Conference on Artificial Intelligence and
       Statistics (AISTATS), 2019.

    """

    Xs = list_to_array(*Xs)
    Ys = list_to_array(*Ys)
    Xt = list_to_array(Xt)

    nx = get_backend(*Xs, *Ys, Xt)

    # the number of classes is read from the first source domain
    nbclasses = len(nx.unique(Ys[0]))
    nbdomains = len(Xs)

    # log dictionary
    if log:
        log = {"niter": 0, "err": [], "M": [], "D1": [], "D2": [], "gamma": []}

    K = []
    M = []
    D1 = []
    D2 = []

    # For each source domain, build cost matrices M, Gibbs kernels K and
    # corresponding matrices D_1 and D_2
    for d in range(nbdomains):
        nsk = Xs[d].shape[0]  # number of elements for this domain
        labels = Ys[d]
        classes = nx.unique(labels)  # classes present in this domain

        # format classes to start from 0 for convenience; work on a shifted
        # copy so that the caller's label arrays are never modified in place
        if nx.min(classes) != 0:
            labels = labels - nx.min(classes)
            classes = nx.unique(labels)

        # build the corresponding D_1 and D_2 matrices
        Dtmp1 = np.zeros((nbclasses, nsk))
        Dtmp2 = np.zeros((nbclasses, nsk))

        for c in classes:
            in_class = labels == c
            nbelemperclass = float(nx.sum(in_class))
            if nbelemperclass != 0:
                mask = nx.to_numpy(in_class)
                # D_1 flags class membership, D_2 spreads a unit mass
                # uniformly over the members of the class
                Dtmp1[int(c), mask] = 1.0
                Dtmp2[int(c), mask] = 1.0 / (nbelemperclass)
        D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0]))
        D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0]))

        # build the cost matrix and the Gibbs kernel
        Mtmp = dist(Xs[d], Xt, metric=metric)
        M.append(Mtmp)

        Ktmp = nx.exp(-Mtmp / reg)
        K.append(Ktmp)

    # uniform target distribution
    a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0])

    err = 1
    old_bary = nx.ones((nbclasses,), type_as=Xs[0])

    for ii in range(numItermax):
        # accumulated in the log domain, exponentiated after the loop below
        bary = nx.zeros((nbclasses,), type_as=Xs[0])

        # update coupling matrices for marginal constraints w.r.t. uniform
        # target distribution
        for d in range(nbdomains):
            K[d] = projC(K[d], a)
            other = nx.sum(K[d], axis=1)
            bary += nx.log(nx.dot(D1[d], other)) / nbdomains

        bary = nx.exp(bary)

        # update coupling matrices for marginal constraints w.r.t. unknown
        # proportions based on [Prop 4., 27]
        for d in range(nbdomains):
            new = nx.dot(D2[d].T, bary)
            K[d] = projR(K[d], new)

        err = nx.norm(bary - old_bary)

        old_bary = bary

        if log:
            log["err"].append(err)

        if err < stopThr:
            break
        if verbose:
            if ii % 200 == 0:
                print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
            print("{:5d}|{:8e}|".format(ii, err))
    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn(
                "Algorithm did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )

    # normalize so that the estimated proportions sum to one
    bary = bary / nx.sum(bary)

    if log:
        log["niter"] = ii
        log["M"] = M
        log["D1"] = D1
        log["D2"] = D2
        log["gamma"] = K
        return bary, log
    else:
        return bary