Computing d-dimensional Barycenters via d-MMOT

When the ground cost is a discretized Monge cost, the d-MMOT solver can compute and minimize the transport cost among many distributions more quickly, without any intermediate barycenter computations. This example compares the running time and the solution quality of the d-MMOT primal/dual algorithm against classical LP barycenter approaches.
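Much of the speed-up comes from the structure of Monge costs on a 1D grid. Under the Monge property, an optimal multi-marginal coupling can be built greedily, in north-west-corner fashion, without any LP solve. The greedy step repeatedly moves the smallest remaining bin mass and then advances every marginal whose current bin is empty.

The sketch below illustrates that idea for d histograms. It is not POT's implementation, and it makes two assumptions:

- the cost on bin indices is c(i_1, ..., i_d) = max_k i_k - min_k i_k;
- `greedy_mmot_1d` is a hypothetical helper name.

```python
import numpy as np


def greedy_mmot_1d(A, tol=1e-12):
    """Greedy (north-west corner) coupling of d histograms on a 1D grid.

    A is an (n, d) array whose columns are histograms with equal total mass.
    Returns the cost under the ASSUMED cost max(i) - min(i) and the coupling
    as a list of (bin-index tuple, mass) pairs.
    """
    A = np.asarray(A, dtype=np.float64)
    n, d = A.shape
    cols = np.arange(d)
    rem = A.copy()                  # mass not yet assigned in each bin
    idx = np.zeros(d, dtype=int)    # current bin of each marginal
    cost, plan = 0.0, []
    while np.all(idx < n):
        # Largest mass that can move without overdrawing any current bin
        m = rem[idx, cols].min()
        if m > 0:
            cost += m * (idx.max() - idx.min())
            plan.append((tuple(idx.tolist()), m))
        rem[idx, cols] -= m
        # Advance every marginal whose current bin is now (numerically) empty
        idx[rem[idx, cols] <= tol] += 1
    return cost, plan
```

For d = 2 the assumed cost reduces to |i - j|. The greedy coupling is then the monotone one, so its cost should match the 1D Wasserstein-1 distance `np.abs(np.cumsum(a) - np.cumsum(b)).sum()`.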

# Author: Ronak Mehta <ronakrm@cs.wisc.edu>
#         Xizheng Yu <xyu354@wisc.edu>
#
# License: MIT License

Generating 2 distributions

import numpy as np
import matplotlib.pyplot as pl
import ot

np.random.seed(0)

n = 100
d = 2
# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m=mean, s=std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
A = np.vstack((a1, a2)).T
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.legend()
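For reference, the inputs above can be rebuilt with NumPy alone. This sketch rests on two assumptions:

- `ot.datasets.make_1D_gauss` returns a Gaussian profile sampled at the integer bins and normalized to sum to 1;
- `ot.utils.dist(..., metric='minkowski')` reduces to |x_i - x_j| for 1D points, which holds for every Minkowski order p.

`gauss_hist` is a hypothetical helper name.

```python
import numpy as np


def gauss_hist(n, m, s):
    """Gaussian profile on bins 0..n-1, normalized to a probability histogram."""
    x = np.arange(n, dtype=np.float64)
    h = np.exp(-((x - m) ** 2) / (2 * s ** 2))
    return h / h.sum()


n = 100
x = np.arange(n, dtype=np.float64)
# Columns are the two histograms, matching np.vstack((a1, a2)).T
A = np.stack([gauss_hist(n, 20, 5), gauss_hist(n, 60, 8)], axis=1)

# For 1D points every Minkowski distance is the absolute difference
M = np.abs(x[:, None] - x[None, :])
```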

Minimize the distances among distributions and identify the barycenter

The two methods minimize different objectives, so their objective values cannot be compared directly.
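The three approaches optimize different quantities:

- `ot.lp.barycenter` solves min_b sum_k w_k W_M(b, a_k), with ground cost M.
- The L2 barycenter minimizes sum_k w_k ||b - a_k||^2, whose minimizer is the weighted mean `A @ w`.
- The d-MMOT solver drives down the multi-marginal transport cost of the histograms themselves.

One hedged way to put candidate barycenters on a common scale is to evaluate the same weighted sum of 1D Wasserstein-1 distances for each. The sketch uses the CDF formula W1(a, b) = sum_i |F_a(i) - F_b(i)| for unit-spaced bins, which matches the cost |x_i - x_j| used for M here. The helper names and the toy data are illustrative only.

```python
import numpy as np


def w1_1d(a, b):
    """Wasserstein-1 distance between histograms on unit-spaced bins 0..n-1."""
    return float(np.abs(np.cumsum(a) - np.cumsum(b)).sum())


def weighted_w1_cost(bary, A, weights):
    """sum_k w_k * W1(bary, A[:, k]) for histograms stored as the columns of A."""
    return sum(w * w1_1d(bary, A[:, k]) for k, w in enumerate(weights))


# Toy data: point masses at bins 0, 1 and 4 with equal weights
A_toy = np.zeros((5, 3))
A_toy[0, 0] = A_toy[1, 1] = A_toy[4, 2] = 1.0
w_toy = np.full(3, 1.0 / 3.0)

l2_toy = A_toy @ w_toy      # L2 barycenter: mass spread over bins 0, 1 and 4
median_toy = np.eye(5)[1]   # all mass at the median bin
print(weighted_w1_cost(l2_toy, A_toy, w_toy), weighted_w1_cost(median_toy, A_toy, w_toy))
```

In this toy case the weighted mean scores worse (16/9) under the W1 criterion than the point mass at the median bin (4/3). That gap is why the L2 and LP barycenters in the plots below can look quite different.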

# L2 barycenter: the weighted average of the histograms
weights = np.ones(d) / d
l2_bary = A.dot(weights)

print('LP Iterations:')
ot.tic()  # start the timer so toc() measures only the LP solve
lp_bary, lp_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)
print('Time\t: ', ot.toc(''))
print('Obj\t: ', lp_log['fun'])

print('')
print('Discrete MMOT Algorithm:')
ot.tic()
barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
    A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True)
dmmot_obj = log['primal objective']
print('Time\t: ', ot.toc(''))
print('Obj\t: ', dmmot_obj)
LP Iterations:
/home/circleci/project/ot/lp/cvx.py:125: OptimizeWarning: Sparse constraint matrix detected; setting 'sparse':True.
  sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,

Time    :  306.7678418159485
Obj     :  19.999774740910517

Discrete MMOT Algorithm:
Inital:         Obj:    39.9995 GradNorm:       739.7831
Iter  0:        Obj:    39.9995 GradNorm:       739.7831
Iter 100:       Obj:    2.0914  GradNorm:       180.6322
Iter 200:       Obj:    1.0583  GradNorm:       434.3777
Iter 300:       Obj:    0.4220  GradNorm:       252.9269
Iter 400:       Obj:    0.2317  GradNorm:       168.8668
Iter 500:       Obj:    0.2116  GradNorm:       384.2968
Iter 600:       Obj:    0.1755  GradNorm:       647.6758
Iter 700:       Obj:    0.1343  GradNorm:       786.2442
Iter 800:       Obj:    0.1021  GradNorm:       810.3703
Iter 900:       Obj:    0.0662  GradNorm:       810.3703
Iter 1000:      Obj:    0.0539  GradNorm:       741.7304
Iter 1100:      Obj:    0.0348  GradNorm:       621.4660
Iter 1200:      Obj:    0.0338  GradNorm:       764.3429
Iter 1300:      Obj:    0.0200  GradNorm:       556.2338
Iter 1400:      Obj:    0.0182  GradNorm:       765.8329
Iter 1500:      Obj:    0.0103  GradNorm:       579.8241
Iter 1600:      Obj:    0.0075  GradNorm:       638.2570
Iter 1700:      Obj:    0.0045  GradNorm:       320.1562
Iter 1800:      Obj:    0.0035  GradNorm:       479.8625
Iter 1900:      Obj:    0.0032  GradNorm:       647.1939
Iter 2000:      Obj:    0.0022  GradNorm:       442.4975
Iter 2100:      Obj:    0.0015  GradNorm:       61.0901
Iter 2200:      Obj:    0.0016  GradNorm:       464.9430
Iter 2300:      Obj:    0.0014  GradNorm:       382.5650
Iter 2400:      Obj:    0.0011  GradNorm:       287.2281
Iter 2500:      Obj:    0.0011  GradNorm:       355.6796
Iter 2600:      Obj:    0.0010  GradNorm:       280.1357
Iter 2700:      Obj:    0.0010  GradNorm:       289.6964
Iter 2800:      Obj:    0.0010  GradNorm:       184.4234
Iter 2900:      Obj:    0.0009  GradNorm:       246.5847
Iter 3000:      Obj:    0.0009  GradNorm:       65.3299
Iter 3100:      Obj:    0.0009  GradNorm:       185.9355
Iter 3200:      Obj:    0.0009  GradNorm:       263.0209
Iter 3300:      Obj:    0.0009  GradNorm:       300.3132
Iter 3400:      Obj:    0.0009  GradNorm:       231.4044
Iter 3500:      Obj:    0.0009  GradNorm:       226.3184
Iter 3600:      Obj:    0.0009  GradNorm:       211.4237
Iter 3700:      Obj:    0.0009  GradNorm:       233.2981
Iter 3800:      Obj:    0.0009  GradNorm:       299.0853
Iter 3900:      Obj:    0.0009  GradNorm:       262.4271

Time    :  4.13730001449585
Obj     :  0.0008940778156514197
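The log flattens out near 0.0009 after roughly 2800 iterations. Suppose `lr_decay` is applied multiplicatively once per iteration, so that step k uses lr_init * lr_decay**k. This schedule is an assumption, not something the source confirms. Under it, the step size has shrunk by several orders of magnitude by that point, which would be consistent with the plateau. The arithmetic is quick to check:

```python
import numpy as np

lr_init, lr_decay, niters = 1e-5, 0.997, 4000

# ASSUMED schedule: lr_k = lr_init * lr_decay**k
k = np.arange(niters)
lr = lr_init * lr_decay ** k

for it in (0, 1000, 2800, niters - 1):
    print(f"iter {it:4d}: lr = {lr[it]:.2e}")

# First iteration at which the step size drops below 1% of lr_init
k_1pct = int(np.ceil(np.log(0.01) / np.log(lr_decay)))
print("step size below 1% of lr_init from iteration", k_1pct)
```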

Compare the barycenters from the different methods

pl.figure(2, figsize=(6.4, 3))
# After optimization the d-MMOT histograms nearly coincide, so plot only the first
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, label='LP Barycenter')
pl.plot(x, l2_bary, label='L2 Barycenter')
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.title('Monge cost: barycenters from the LP and d-MMOT solvers')
pl.legend()

More than 2 distributions

Generate 7 pseudorandom Gaussian distributions with 50 bins.

n = 50  # number of bins
d = 7

data = []
for i in range(d):
    # Random mean in [0, n): n * U[0, 0.5) times a factor of 1 or 2
    m = n * (0.5 * np.random.rand()) * float(np.random.randint(2) + 1)
    a = ot.datasets.make_1D_gauss(n, m=m, s=5)
    data.append(a)

x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
A = np.vstack(data).T

pl.figure(3, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i], label=f'Distribution {i}')

pl.title('Distributions')
pl.legend()
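In the loop above, each mean is drawn as n * U[0, 0.5) scaled by 1 or 2, so the centres land in [0, n). The histograms that drew the factor 2 can sit anywhere on the grid. A vectorized version of the same sampling is sketched below. It uses NumPy's `default_rng` rather than the legacy global seed, so it will not reproduce the example's exact distributions. `random_gauss_hists` is a hypothetical helper name.

```python
import numpy as np


def random_gauss_hists(n, d, s=5.0, seed=0):
    """d Gaussian histograms on n bins with random means, stacked as columns."""
    rng = np.random.default_rng(seed)
    # Same distribution as the loop: n * U[0, 0.5) scaled by a factor of 1 or 2
    means = n * 0.5 * rng.random(d) * (rng.integers(0, 2, size=d) + 1)
    x = np.arange(n, dtype=np.float64)
    H = np.exp(-((x[:, None] - means[None, :]) ** 2) / (2 * s ** 2))
    # Normalize each column to a probability histogram
    return H / H.sum(axis=0, keepdims=True), means


A_rand, means = random_gauss_hists(50, 7)
```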

Minimizing Distances Among Many Distributions

As before, the two methods minimize different objectives, so their objective values cannot be compared directly.

# Perform gradient descent optimization using the d-MMOT method.
barys = ot.lp.dmmot_monge_1dgrid_optimize(
    A, niters=3000, lr_init=1e-4, lr_decay=0.997)

# After minimization, any of the d distributions can be used as an estimate of the barycenter.
bary = barys[0]

# L2 barycenter (weighted mean) and LP Wasserstein barycenter, uniform weights
weights = ot.unif(d)
l2_bary = A.dot(weights)
lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point',
                                     verbose=False, log=True)
Inital:         Obj:    37.1964 GradNorm:       284.3413
Iter  0:        Obj:    37.1964 GradNorm:       280.9858
Iter 100:       Obj:    3.3320  GradNorm:       136.2204
Iter 200:       Obj:    0.7755  GradNorm:       156.3650
Iter 300:       Obj:    0.4874  GradNorm:       229.6258
Iter 400:       Obj:    0.3684  GradNorm:       238.1008
Iter 500:       Obj:    0.3353  GradNorm:       280.3034
Iter 600:       Obj:    0.2220  GradNorm:       267.4771
Iter 700:       Obj:    0.1678  GradNorm:       284.3413
Iter 800:       Obj:    0.1315  GradNorm:       284.3413
Iter 900:       Obj:    0.0706  GradNorm:       271.7241
Iter 1000:      Obj:    0.0567  GradNorm:       269.5960
Iter 1100:      Obj:    0.0420  GradNorm:       250.8386
Iter 1200:      Obj:    0.0345  GradNorm:       271.8676
Iter 1300:      Obj:    0.0230  GradNorm:       230.8679
Iter 1400:      Obj:    0.0145  GradNorm:       217.5960
Iter 1500:      Obj:    0.0122  GradNorm:       244.9326
Iter 1600:      Obj:    0.0089  GradNorm:       207.8076
Iter 1700:      Obj:    0.0064  GradNorm:       175.3682
Iter 1800:      Obj:    0.0054  GradNorm:       208.5713
Iter 1900:      Obj:    0.0042  GradNorm:       196.3110
Iter 2000:      Obj:    0.0035  GradNorm:       189.1930
Iter 2100:      Obj:    0.0028  GradNorm:       152.9444
Iter 2200:      Obj:    0.0025  GradNorm:       154.9903
Iter 2300:      Obj:    0.0023  GradNorm:       165.5778
Iter 2400:      Obj:    0.0020  GradNorm:       162.6161
Iter 2500:      Obj:    0.0019  GradNorm:       148.9564
Iter 2600:      Obj:    0.0018  GradNorm:       150.7780
Iter 2700:      Obj:    0.0017  GradNorm:       160.8478
Iter 2800:      Obj:    0.0017  GradNorm:       145.0931
Iter 2900:      Obj:    0.0016  GradNorm:       128.1718
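Using any one of the optimized distributions as the barycenter estimate assumes that the d columns have nearly collapsed onto one another. A simple, hedged check is the largest pairwise 1D Wasserstein-1 distance among them, again computed from cumulative sums on unit-spaced bins. `max_pairwise_w1` is a hypothetical helper. It takes any sequence of 1D histograms, so it could be called on `barys` from the cell above.

```python
import itertools

import numpy as np


def max_pairwise_w1(hists):
    """Largest Wasserstein-1 distance between any two histograms on bins 0..n-1."""
    cdfs = [np.cumsum(np.asarray(h, dtype=np.float64)) for h in hists]
    return max(
        float(np.abs(F - G).sum()) for F, G in itertools.combinations(cdfs, 2)
    )


# Illustration: identical histograms, then one copy shifted by one bin
h = np.array([0.25, 0.5, 0.25, 0.0])
print(max_pairwise_w1([h, h, h]))                 # 0.0
print(max_pairwise_w1([h, h, np.roll(h, 1)]))     # 1.0
```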

Compare the barycenters from the different methods

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, bary, 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k', label='L2 Barycenter')
pl.plot(x, lp_bary, 'k--', label='LP Wasserstein')  # dashed, to tell it apart from L2
pl.title('Barycenters')
pl.legend()

Compare with original distributions

pl.figure(5, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])
# The d-MMOT histograms nearly coincide after optimization; plot the first one
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k^', label='L2')
pl.plot(x, lp_bary, 'o', color='grey', label='LP')
pl.title('Barycenters and original distributions')
pl.legend()
pl.show()

Total running time of the script: (0 minutes 33.678 seconds)
