# Source code for ot.bregman._utils

# -*- coding: utf-8 -*-
"""
Common tools of Bregman projections solvers for entropic regularized OT
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

from ..utils import list_to_array
from ..backend import get_backend


def geometricBar(weights, alldistribT):
    """Return the weighted geometric mean of distributions.

    The result is ``exp(dot(log(alldistribT), weights.T))``, evaluated with
    the backend inferred from the inputs.

    Parameters
    ----------
    weights : array-like
        Weights applied to the distributions. ``len(weights)`` must equal
        ``alldistribT.shape[1]``.
    alldistribT : array-like
        Stacked distributions; presumably one distribution per column
        (second axis) -- confirm against callers.

    Returns
    -------
    array-like
        Backend array holding the weighted geometric mean.

    Raises
    ------
    AssertionError
        If the number of weights differs from ``alldistribT.shape[1]``.
    """
    weights, alldistribT = list_to_array(weights, alldistribT)
    nx = get_backend(weights, alldistribT)

    # One weight is expected per entry along the second axis.
    assert len(weights) == alldistribT.shape[1]

    # A weighted arithmetic mean in log space is a geometric mean.
    log_distribs = nx.log(alldistribT)
    weighted_log = nx.dot(log_distribs, weights.T)
    return nx.exp(weighted_log)
def geometricMean(alldistribT):
    """Return the (unweighted) geometric mean of distributions.

    The result is ``exp(mean(log(alldistribT), axis=1))``, evaluated with
    the backend inferred from the input.

    Parameters
    ----------
    alldistribT : array-like
        Stacked distributions; the mean is taken along axis 1, so
        presumably each column is one distribution -- confirm against
        callers.

    Returns
    -------
    array-like
        Backend array holding the geometric mean taken over axis 1.
    """
    alldistribT = list_to_array(alldistribT)
    nx = get_backend(alldistribT)

    # Average in log space, then map back with exp.
    mean_log = nx.mean(nx.log(alldistribT), axis=1)
    return nx.exp(mean_log)
def projR(gamma, p):
    """Return the KL projection on the row constraints.

    Computes ``(gamma.T * p / max(sum(gamma, axis=1), 1e-10)).T`` with the
    backend inferred from the inputs.

    Parameters
    ----------
    gamma : array-like
        Matrix to project; assumed 2-D -- confirm against callers.
    p : array-like
        Target row marginal; assumed to hold one entry per row of
        ``gamma`` -- confirm against callers.

    Returns
    -------
    array-like
        Rescaled matrix with the same layout as ``gamma``.
    """
    gamma, p = list_to_array(gamma, p)
    nx = get_backend(gamma, p)

    # Work on the transpose so p broadcasts along gamma's first axis.
    scaled_t = gamma.T * p
    # Clip the row sums from below to avoid dividing by zero.
    row_sums = nx.maximum(nx.sum(gamma, axis=1), 1e-10)
    return (scaled_t / row_sums).T
def projC(gamma, q):
    """Return the KL projection on the column constraints.

    Computes ``gamma * q / max(sum(gamma, axis=0), 1e-10)`` with the
    backend inferred from the inputs.

    Parameters
    ----------
    gamma : array-like
        Matrix to project; assumed 2-D -- confirm against callers.
    q : array-like
        Target column marginal; assumed to hold one entry per column of
        ``gamma`` -- confirm against callers.

    Returns
    -------
    array-like
        Rescaled matrix with the same layout as ``gamma``.
    """
    gamma, q = list_to_array(gamma, q)
    nx = get_backend(gamma, q)

    scaled = gamma * q
    # Clip the column sums from below to avoid dividing by zero.
    col_sums = nx.maximum(nx.sum(gamma, axis=0), 1e-10)
    return scaled / col_sums