Computing d-dimensional Barycenters via d-MMOT
When the cost is discretized (Monge), the d-MMOT solver can compute and minimize the distance between many distributions more quickly, without intermediate barycenter computations. This example compares the running time and solution quality of a primal/dual d-MMOT algorithm with those of the classical LP barycenter approach.
# Author: Ronak Mehta <ronakrm@cs.wisc.edu>
# Xizheng Yu <xyu354@wisc.edu>
#
# License: MIT License
Generating 2 distributions
import numpy as np
import matplotlib.pyplot as pl
import ot
np.random.seed(0)
n = 100
d = 2
# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m=mean, s=std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
A = np.vstack((a1, a2)).T
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.legend()
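The inputs above are histograms on a shared grid of n bins, stacked column-wise so that A has shape (n, d). As a minimal sketch for readers who want to inspect this layout without POT installed, the hypothetical helper `gauss_hist` below is assumed to behave like `ot.datasets.make_1D_gauss` (a Gaussian bump evaluated on bins 0..n-1 and normalized to sum to 1). It is an illustration under that assumption, not the library's code.

```python
import numpy as np


def gauss_hist(n, m, s):
    """Hypothetical stand-in: Gaussian histogram on bins 0..n-1, normalized to sum to 1."""
    x = np.arange(n, dtype=np.float64)
    h = np.exp(-((x - m) ** 2) / (2 * s ** 2))
    return h / h.sum()


n = 100
a1 = gauss_hist(n, m=20, s=5)
a2 = gauss_hist(n, m=60, s=8)

# One histogram per column, matching the layout used for A in this example.
A = np.vstack((a1, a2)).T

print(A.shape)        # (number of bins, number of distributions)
print(A.sum(axis=0))  # each column should sum to 1 up to rounding
```

Keeping one distribution per column is what lets a single weight vector of length d combine them with `A.dot(weights)`.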
Minimize the distances among distributions, identify the Barycenter
The two methods minimize different objectives, so their objective values cannot be compared directly.
# L2 Iteration
weights = np.ones(d) / d
l2_bary = A.dot(weights)
print('LP Iterations:')
ot.tic()  # start the timer so that ot.toc() below measures only the LP solve
lp_bary, lp_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)
print('Time\t: ', ot.toc(''))
print('Obj\t: ', lp_log['fun'])
print('')
print('Discrete MMOT Algorithm:')
ot.tic()
barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
    A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True)
dmmot_obj = log['primal objective']
print('Time\t: ', ot.toc(''))
print('Obj\t: ', dmmot_obj)
LP Iterations:
/home/circleci/project/ot/lp/cvx.py:125: OptimizeWarning: Sparse constraint matrix detected; setting 'sparse':True.
sol = sp.optimize.linprog(c, A_eq=A_eq, b_eq=b_eq, method=solver,
Time : 294.8714361190796
Obj : 19.99977474094027
Discrete MMOT Algorithm:
Inital: Obj: 39.9995 GradNorm: 739.7831
Iter 0: Obj: 39.9995 GradNorm: 739.7831
Iter 100: Obj: 2.0914 GradNorm: 180.6322
Iter 200: Obj: 1.0583 GradNorm: 434.3777
Iter 300: Obj: 0.4220 GradNorm: 252.9269
Iter 400: Obj: 0.2317 GradNorm: 168.8668
Iter 500: Obj: 0.2116 GradNorm: 384.2968
Iter 600: Obj: 0.1755 GradNorm: 647.6758
Iter 700: Obj: 0.1343 GradNorm: 786.2442
Iter 800: Obj: 0.1021 GradNorm: 810.3703
Iter 900: Obj: 0.0662 GradNorm: 810.3703
Iter 1000: Obj: 0.0539 GradNorm: 741.7304
Iter 1100: Obj: 0.0348 GradNorm: 621.4660
Iter 1200: Obj: 0.0338 GradNorm: 764.3429
Iter 1300: Obj: 0.0200 GradNorm: 556.2338
Iter 1400: Obj: 0.0182 GradNorm: 765.8329
Iter 1500: Obj: 0.0103 GradNorm: 579.8241
Iter 1600: Obj: 0.0075 GradNorm: 638.2570
Iter 1700: Obj: 0.0045 GradNorm: 320.1562
Iter 1800: Obj: 0.0035 GradNorm: 479.8625
Iter 1900: Obj: 0.0032 GradNorm: 647.1939
Iter 2000: Obj: 0.0022 GradNorm: 442.4975
Iter 2100: Obj: 0.0015 GradNorm: 61.0901
Iter 2200: Obj: 0.0016 GradNorm: 464.9430
Iter 2300: Obj: 0.0014 GradNorm: 382.5650
Iter 2400: Obj: 0.0011 GradNorm: 287.2281
Iter 2500: Obj: 0.0011 GradNorm: 355.6796
Iter 2600: Obj: 0.0010 GradNorm: 280.1357
Iter 2700: Obj: 0.0010 GradNorm: 289.6964
Iter 2800: Obj: 0.0010 GradNorm: 184.4234
Iter 2900: Obj: 0.0009 GradNorm: 246.5847
Iter 3000: Obj: 0.0009 GradNorm: 65.3299
Iter 3100: Obj: 0.0009 GradNorm: 185.9355
Iter 3200: Obj: 0.0009 GradNorm: 263.0209
Iter 3300: Obj: 0.0009 GradNorm: 300.3132
Iter 3400: Obj: 0.0009 GradNorm: 231.4044
Iter 3500: Obj: 0.0009 GradNorm: 226.3184
Iter 3600: Obj: 0.0009 GradNorm: 211.4237
Iter 3700: Obj: 0.0009 GradNorm: 233.2981
Iter 3800: Obj: 0.0009 GradNorm: 299.0853
Iter 3900: Obj: 0.0009 GradNorm: 262.4271
Time : 4.05252742767334
Obj : 0.0008940778156514197
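Because the two printed objectives are not comparable, one solver-independent check is to score any candidate barycenter by its average 1D Wasserstein-1 distance to the inputs. The sketch below assumes unit-spaced bins and an absolute-difference ground cost, in which case the distance reduces to the L1 norm of the difference of cumulative sums. The helper names `w1_on_grid` and `mean_w1_to_inputs` are hypothetical and not part of POT.

```python
import numpy as np


def w1_on_grid(a, b):
    """Wasserstein-1 distance between two histograms on a unit-spaced 1D grid."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # In 1D, W1 is the area between the two cumulative distribution functions.
    return float(np.abs(np.cumsum(a - b)).sum())


def mean_w1_to_inputs(candidate, A):
    """Average W1 distance from a candidate histogram to each column of A."""
    return float(np.mean([w1_on_grid(candidate, A[:, k])
                          for k in range(A.shape[1])]))


# Toy inputs: two point masses four bins apart.
a1 = np.array([1.0, 0.0, 0.0, 0.0, 0.0])
a2 = np.array([0.0, 0.0, 0.0, 0.0, 1.0])
A_toy = np.vstack((a1, a2)).T

midpoint = np.array([0.0, 0.0, 1.0, 0.0, 0.0])

print(w1_on_grid(a1, a2))
print(mean_w1_to_inputs(midpoint, A_toy))
```

In the example itself, the same function could be applied to `barys[0]`, `lp_bary` and `l2_bary` with the matrix `A` to put all three candidates on one scale.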
Compare Barycenters in both methods
pl.figure(1, figsize=(6.4, 3))
# Any of the optimized marginals can serve as the barycenter estimate; plot the first.
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, label='LP Barycenter')
pl.plot(x, l2_bary, label='L2 Barycenter')
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver')
pl.legend()
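The L2 barycenter in the comparison above is just the pointwise weighted average `A.dot(weights)`, so for well-separated inputs it keeps one bump per input instead of moving mass between them. The sketch below illustrates this with plain NumPy. `gauss_hist` is a hypothetical stand-in assumed to mimic `ot.datasets.make_1D_gauss`, and `local_maxima` is a small helper introduced only for this illustration.

```python
import numpy as np


def gauss_hist(n, m, s):
    """Hypothetical stand-in: Gaussian histogram on bins 0..n-1, normalized to sum to 1."""
    x = np.arange(n, dtype=np.float64)
    h = np.exp(-((x - m) ** 2) / (2 * s ** 2))
    return h / h.sum()


def local_maxima(v):
    """Indices whose value is strictly greater than both neighbours."""
    v = np.asarray(v)
    return [i for i in range(1, len(v) - 1)
            if v[i] > v[i - 1] and v[i] > v[i + 1]]


n = 100
A = np.vstack((gauss_hist(n, 20, 5), gauss_hist(n, 60, 8))).T
weights = np.ones(2) / 2

# Pointwise average of the two histograms.
l2_bary = A.dot(weights)

peaks = local_maxima(l2_bary)
print(peaks)  # one peak index near each input mean
```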
More than 2 distributions
Generate 7 pseudorandom Gaussian distributions with 50 bins.
n = 50 # nb bins
d = 7
vecsize = n * d
data = []
for i in range(d):
    m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
    a = ot.datasets.make_1D_gauss(n, m=m, s=5)
    data.append(a)
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
A = np.vstack(data).T
pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])
pl.title('Distributions')
Minimizing Distances Among Many Distributions
The two methods minimize different objectives, so their objective values cannot be compared directly.
# Perform gradient descent optimization using the d-MMOT method.
barys = ot.lp.dmmot_monge_1dgrid_optimize(
    A, niters=3000, lr_init=1e-4, lr_decay=0.997)
# After minimization, any of the returned distributions can be used as an estimate of the barycenter.
bary = barys[0]
# Compute 1D Wasserstein barycenter using the L2/LP method
weights = ot.unif(d)
l2_bary = A.dot(weights)
lp_bary, bary_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)
Inital: Obj: 37.1964 GradNorm: 284.3413
Iter 0: Obj: 37.1964 GradNorm: 280.9858
Iter 100: Obj: 3.3320 GradNorm: 136.2204
Iter 200: Obj: 0.7755 GradNorm: 156.3650
Iter 300: Obj: 0.4874 GradNorm: 229.6258
Iter 400: Obj: 0.3684 GradNorm: 238.1008
Iter 500: Obj: 0.3353 GradNorm: 280.3034
Iter 600: Obj: 0.2220 GradNorm: 267.4771
Iter 700: Obj: 0.1678 GradNorm: 284.3413
Iter 800: Obj: 0.1315 GradNorm: 284.3413
Iter 900: Obj: 0.0706 GradNorm: 271.7241
Iter 1000: Obj: 0.0567 GradNorm: 269.5960
Iter 1100: Obj: 0.0420 GradNorm: 250.8386
Iter 1200: Obj: 0.0345 GradNorm: 271.8676
Iter 1300: Obj: 0.0230 GradNorm: 230.8679
Iter 1400: Obj: 0.0145 GradNorm: 217.5960
Iter 1500: Obj: 0.0122 GradNorm: 244.9326
Iter 1600: Obj: 0.0089 GradNorm: 207.8076
Iter 1700: Obj: 0.0064 GradNorm: 175.3682
Iter 1800: Obj: 0.0054 GradNorm: 208.5713
Iter 1900: Obj: 0.0042 GradNorm: 196.3110
Iter 2000: Obj: 0.0035 GradNorm: 189.1930
Iter 2100: Obj: 0.0028 GradNorm: 152.9444
Iter 2200: Obj: 0.0025 GradNorm: 154.9903
Iter 2300: Obj: 0.0023 GradNorm: 165.5778
Iter 2400: Obj: 0.0020 GradNorm: 162.6161
Iter 2500: Obj: 0.0019 GradNorm: 148.9564
Iter 2600: Obj: 0.0018 GradNorm: 150.7780
Iter 2700: Obj: 0.0017 GradNorm: 160.8478
Iter 2800: Obj: 0.0017 GradNorm: 145.0931
Iter 2900: Obj: 0.0016 GradNorm: 128.1718
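The `lr_init` and `lr_decay` arguments suggest a geometrically decaying step size. Assuming the solver multiplies the step size by `lr_decay` once per iteration (an assumption about the implementation, not something stated in this example), the step at iteration t would be `lr_init * lr_decay ** t`. The hypothetical helper `step_size` below shows how quickly that shrinks for the settings used here.

```python
def step_size(t, lr_init, lr_decay):
    """Step size at iteration t under an assumed geometric decay schedule."""
    return lr_init * lr_decay ** t


# Settings from the 7-distribution run: lr_init=1e-4, lr_decay=0.997.
for t in (0, 1000, 2000, 3000):
    print(t, step_size(t, lr_init=1e-4, lr_decay=0.997))
```

Under this assumption, a smaller `lr_decay` shortens the phase in which the iterates can still move noticeably, so `niters` and `lr_decay` are best tuned together.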
Compare Barycenters in both methods
Compare with original distributions
pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])
# Any of the optimized marginals can serve as the barycenter estimate; plot the first.
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k^', label='L2')
pl.plot(x, lp_bary, 'o', color='grey', label='LP')
pl.title('Barycenters')
pl.legend()
pl.show()
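Using a single returned distribution as the barycenter estimate is only reasonable if the optimized marginals have nearly collapsed onto one another. A minimal sketch of such a check follows. `max_pairwise_l1` is a hypothetical helper, and `toy_barys` is synthetic stand-in data rather than actual solver output. The sketch assumes the solver returns a sequence of d histograms of equal length, as the indexing `barys[i]` in this example suggests.

```python
import numpy as np


def max_pairwise_l1(marginals):
    """Largest L1 distance between any two histograms in the collection."""
    stack = np.asarray(marginals, dtype=float)  # shape (d, n)
    # Pairwise absolute differences, summed over the bin axis.
    dists = np.abs(stack[:, None, :] - stack[None, :, :]).sum(axis=2)
    return float(dists.max())


# Synthetic stand-in: 7 nearly identical histograms on 50 bins.
rng = np.random.default_rng(0)
base = np.full(50, 1.0 / 50)
toy_barys = [base + 1e-4 * rng.standard_normal(50) for _ in range(7)]

spread = max_pairwise_l1(toy_barys)
print(spread)
```

A spread that is small relative to the total mass of 1 would support plotting only `barys[0]`, while a large spread would suggest running more iterations.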
Total running time of the script: (0 minutes 26.468 seconds)