# -*- coding: utf-8 -*-
"""
d-MMOT solvers for optimal transport
"""
# Author: Ronak Mehta <ronakrm@cs.wisc.edu>
# Xizheng Yu <xyu354@wisc.edu>
#
# License: MIT License
import numpy as np
from ..backend import get_backend
def dist_monge_max_min(i):
r"""
A tensor :math:c is Monge if for all valid :math:i_1, \ldots i_d and
:math:j_1, \ldots, j_d,
.. math::
c(s_1, \ldots, s_d) + c(t_1, \ldots t_d) \leq c(i_1, \ldots i_d) +
c(j_1, \ldots, j_d)
where :math:s_k = \min(i_k, j_k) and :math:t_k = \max(i_k, j_k).
Our focus is on a specific cost, which is known to be Monge:
.. math::
c(i_1,i_2,\ldots,i_d) = \max{i_k:k\in[d]} - \min{i_k:k\in[d]}.
When :math:d=2, this cost reduces to :math:c(i_1,i_2)=|i_1-i_2|,
which agrees with the classical EMD cost. This choice of :math:c is called
the generalized EMD cost.
Parameters
----------
i : list
The list of integer indexes.
Returns
-------
cost : numeric value
The ground cost (generalized EMD cost) of the tensor.
References
----------
.. [56] Jeffery Kline. Properties of the d-dimensional earth mover's
problem. Discrete Applied Mathematics, 265: 128-141, 2019.
.. [57] Wolfgang W. Bein, Peter Brucker, James K. Park, and Pramod K.
Pathak. A monge property for the d-dimensional transportation problem.
Discrete Applied Mathematics, 58(2):97-109, 1995. ISSN 0166-218X. doi:
https://doi.org/10.1016/0166-218X(93)E0121-E. URL
https://www.sciencedirect.com/ science/article/pii/0166218X93E0121E.
Workshop on Discrete Algorithms.
"""
return max(i) - min(i)
[docs]
def dmmot_monge_1dgrid_loss(A, verbose=False, log=False):
r"""
Compute the discrete multi-marginal optimal transport of distributions A.
This function operates on distributions whose supports are real numbers on
the real line.
The algorithm solves both primal and dual d-MMOT programs concurrently to
produce the optimal transport plan as well as the total (minimal) cost.
The cost is a ground cost, and the solution is independent of
which Monge cost is desired.
The algorithm accepts :math:`d` distributions (i.e., histograms)
:math:`a_{1}, \ldots, a_{d} \in \mathbb{R}_{+}^{n}` with :math:`e^{\prime}
a_{j}=1` for all :math:`j \in[d]`. Although the algorithm states that all
histograms have the same number of bins, the algorithm can be easily
adapted to accept as inputs :math:`a_{i} \in \mathbb{R}_{+}^{n_{i}}`
with :math:`n_{i} \neq n_{j}` [50].
The function solves the following optimization problem[51]:
.. math::
\begin{align}\begin{aligned}
\underset{\gamma\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}}
\sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, \gamma(i_1,\ldots,i_d)
\quad \textrm{s.t.}
\sum_{i_2,\ldots,i_d} \gamma(i_1,\ldots,i_d) &= a_1(i_i),
(\forall i_1\in[n])\\
\qquad\vdots\\
\sum_{i_1,\ldots,i_{d-1}} \gamma(i_1,\ldots,i_d) &= a_{d}(i_{d}),
(\forall i_d\in[n]).
\end{aligned}
\end{align}
Parameters
----------
A : nx.ndarray, shape (dim, n_hists)
The input ndarray containing distributions of n bins in d dimensions.
verbose : bool, optional
If True, print debugging information during execution. Default=False.
log : bool, optional
If True, record log. Default is False.
Returns
-------
obj : float
the value of the primal objective function evaluated at the solution.
log : dict
A dictionary containing the log of the discrete mmot problem:
- 'A': a dictionary that maps tuples of indices to the corresponding
primal variables. The tuples are the indices of the entries that are
set to their minimum value during the algorithm.
- 'primal objective': a float, the value of the objective function
evaluated at the solution.
- 'dual': a list of arrays, the dual variables corresponding to
the input arrays. The i-th element of the list is the dual variable
corresponding to the i-th dimension of the input arrays.
- 'dual objective': a float, the value of the dual objective function
evaluated at the solution.
References
----------
.. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, &
Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal
Transport Regularization. In The Eleventh International
Conference on Learning Representations.
.. [56] Jeffery Kline. Properties of the d-dimensional earth mover's
problem. Discrete Applied Mathematics, 265: 128-141, 2019.
.. [58] Leonid V Kantorovich. On the translocation of masses. Dokl. Akad.
Nauk SSSR, 37:227-229, 1942.
See Also
--------
ot.lp.dmmot_monge_1dgrid_optimize : Optimize the d-Dimensional Earth
Mover's Distance (d-MMOT)
"""
nx = get_backend(A)
A_copy = A
A = nx.to_numpy(A)
AA = [np.copy(A[:, j]) for j in range(A.shape[1])]
dims = tuple([len(_) for _ in AA])
xx = {}
dual = [np.zeros(d) for d in dims]
idx = [
0,
] * len(AA)
obj = 0
if verbose:
print("i minval oldidx\t\tobj\t\tvals")
while all([i < _ for _, i in zip(dims, idx)]):
vals = [v[i] for v, i in zip(AA, idx)]
minval = min(vals)
i = vals.index(minval)
xx[tuple(idx)] = minval
obj += (dist_monge_max_min(idx)) * minval
for v, j in zip(AA, idx):
v[j] -= minval
# oldidx = nx.copy(idx)
oldidx = idx.copy()
idx[i] += 1
if idx[i] < dims[i]:
temp = (
dist_monge_max_min(idx)
- dist_monge_max_min(oldidx)
+ dual[i][idx[i] - 1]
)
dual[i][idx[i]] += temp
if verbose:
print(i, minval, oldidx, obj, "\t", vals)
# the above terminates when any entry in idx equals the corresponding
# value in dims this leaves other dimensions incomplete; the remaining
# terms of the dual solution must be filled-in
for _, i in enumerate(idx):
try:
dual[_][i:] = dual[_][i]
except Exception:
pass
dualobj = sum([np.dot(A[:, i], arr) for i, arr in enumerate(dual)])
obj = nx.from_numpy(obj)
log_dict = {
"A": xx,
"primal objective": obj,
"dual": dual,
"dual objective": dualobj,
}
# define forward/backward relations for pytorch
obj = nx.set_gradients(obj, (A_copy), (dual))
if log:
return obj, log_dict
else:
return obj
[docs]
def dmmot_monge_1dgrid_optimize(
A,
niters=100,
lr_init=1e-5,
lr_decay=0.995,
print_rate=100,
verbose=False,
log=False,
):
r"""Minimize the d-dimensional EMD using gradient descent.
Discrete Multi-Marginal Optimal Transport (d-MMOT): Let :math:`a_1, \ldots,
a_d\in\mathbb{R}^n_{+}` be discrete probability distributions. Here,
the d-MMOT is the LP,
.. math::
\begin{align}\begin{aligned}
\underset{x\in\mathbb{R}^{n^{d}}_{+}} {\textrm{min}}
\sum_{i_1,\ldots,i_d} c(i_1,\ldots, i_d)\, x(i_1,\ldots,i_d) \quad
\textrm{s.t.}
\sum_{i_2,\ldots,i_d} x(i_1,\ldots,i_d) &= a_1(i_i),
(\forall i_1\in[n])\\
\qquad\vdots\\
\sum_{i_1,\ldots,i_{d-1}} x(i_1,\ldots,i_d) &= a_{d}(i_{d}),
(\forall i_d\in[n]).
\end{aligned}
\end{align}
The dual linear program of the d-MMOT problem is:
.. math::
\underset{z_j\in\mathbb{R}^n, j\in[d]}{\textrm{maximize}}\qquad\sum_{j}
a_j'z_j\qquad \textrm{subject to}\qquad z_{1}(i_1)+\cdots+z_{d}(i_{d})
\leq c(i_1,\ldots,i_{d}),
where the indices in the constraints include all :math:`i_j\in[n]`, :math:
`j\in[d]`. Denote by :math:`\phi(a_1,\ldots,a_d)`, the optimal objective
value of the LP in d-MMOT problem. Let :math:`z^*` be an optimal solution
to the dual program. Then,
.. math::
\begin{align}
\nabla \phi(a_1,\ldots,a_{d}) &= z^*,
~~\text{and for any $t\in \mathbb{R}$,}~~
\phi(a_1,a_2,\ldots,a_{d}) = \sum_{j}a_j'
(z_j^* + t\, \eta), \nonumber \\
\text{where } \eta &:= (z_1^{*}(n)\,e, z^*_1(n)\,e, \cdots,
z^*_{d}(n)\,e)
\end{align}
Using these dual variables naturally provided by the algorithm in
ot.lp.dmmot_monge_1dgrid_loss, gradient steps move each input distribution
to minimize their d-mmot distance.
Parameters
----------
A : nx.ndarray, shape (dim, n_hists)
The input ndarray containing distributions of n bins in d dimensions.
niters : int, optional (default=100)
The maximum number of iterations for the optimization algorithm.
lr_init : float, optional (default=1e-5)
The initial learning rate (step size) for the optimization algorithm.
lr_decay : float, optional (default=0.995)
The learning rate decay rate in each iteration.
print_rate : int, optional (default=100)
The rate at which to print the objective value and gradient norm
during the optimization algorithm.
verbose : bool, optional
If True, print debugging information during execution. Default=False.
log : bool, optional
If True, record log. Default is False.
Returns
-------
a : list of ndarrays, each of shape (n,)
The optimal solution as a list of n approximate barycenters, each of
length vecsize.
log : dict
log dictionary return only if log==True in parameters
References
----------
.. [55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, &
Vikas Singh (2023). Efficient Discrete Multi Marginal Optimal
Transport Regularization. In The Eleventh International
Conference on Learning Representations.
.. [60] Olvi L Mangasarian and RR Meyer. Nonlinear perturbation of linear
programs. SIAM Journal on Control and Optimization, 17(6):745-752, 1979
.. [59] Michael C Ferris and Olvi L Mangasarian. Finite perturbation of
convex programs. Applied Mathematics and Optimization, 23(1):263-273,
1991.
See Also
--------
ot.lp.dmmot_monge_1dgrid_loss: d-Dimensional Earth Mover's Solver
"""
# function body here
nx = get_backend(A)
A = nx.to_numpy(A)
n, d = A.shape # n is dim, d is n_hists
def dualIter(A, lr):
funcval, log_dict = dmmot_monge_1dgrid_loss(A, verbose=verbose, log=True)
grad = np.column_stack(log_dict["dual"])
A_new = np.reshape(A, (n, d)) - grad * lr
return funcval, A_new, grad, log_dict
def renormalize(A):
A = np.reshape(A, (n, d))
for i in range(A.shape[1]):
if min(A[:, i]) < 0:
A[:, i] -= min(A[:, i])
A[:, i] /= np.sum(A[:, i])
return A
def listify(A):
return [A[:, i] for i in range(A.shape[1])]
lr = lr_init
funcval, _, grad, log_dict = dualIter(A, lr)
gn = np.linalg.norm(grad)
print(f"Initial:\t\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}")
for i in range(niters):
A = renormalize(A)
funcval, A, grad, log_dict = dualIter(A, lr)
gn = np.linalg.norm(grad)
if i % print_rate == 0:
print(f"Iter {i:2.0f}:\tObj:\t{funcval:.4f}\tGradNorm:\t{gn:.4f}")
lr *= lr_decay
A = renormalize(A)
a = listify(A)
if log:
return a, log_dict
else:
return a