Stochastic examples
This example shows how to use the stochastic optimization algorithms of the POT library for discrete and semi-continuous measures: the semi-dual solvers of [18] and the dual solver of [19].
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018).
# Author: Kilian Fatras <kilian.fatras@gmail.com>
#
# License: MIT License
import matplotlib.pylab as pl
import numpy as np
import ot
import ot.plot
Compute the Transportation Matrix for the Semi-Dual Problem
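For orientation, the semi-dual problem can be summarized as follows. This follows the entropic formulation of [18] in the notation of this example, writing ε for `reg`. The source potential α is obtained from β by an entropic c-transform, so only β ∈ ℝ^m is optimized. Treat this as a summary; POT's internal objective may differ by additive constants.

```latex
\alpha_i(\beta) = -\varepsilon \log \sum_{j=1}^{m} b_j \exp\!\Big(\frac{\beta_j - M_{ij}}{\varepsilon}\Big),
\qquad
\max_{\beta \in \mathbb{R}^{m}} \; \sum_{i=1}^{n} a_i \Big( \alpha_i(\beta) + \sum_{j=1}^{m} b_j \beta_j - \varepsilon \Big),
\qquad
\pi_{ij} = a_i b_j \exp\!\Big(\frac{\alpha_i(\beta) + \beta_j - M_{ij}}{\varepsilon}\Big).
```

With this choice of α, each row of π sums to a_i exactly. Only the column sums depend on how well β has been optimized.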
Discrete case
Sample two discrete measures a and b for the discrete case and compute their cost matrix M.
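The sampling step can be sketched with NumPy alone. This sketch assumes the same sizes, regularization and random seed as the dual section below (n_source = 7, n_target = 4, reg = 1, seed 0). The identical Sinkhorn outputs printed in both sections are consistent with that assumption. Here `np.full(n, 1 / n)` stands in for `ot.utils.unif(n)`, and the broadcasted squared distance stands in for `ot.dist`, whose default metric is the squared Euclidean distance.

```python
import numpy as np

# Sizes and regularization assumed to match the dual section of this example
n_source = 7
n_target = 4
reg = 1.0

# Uniform weights on each point cloud (what ot.utils.unif computes)
a = np.full(n_source, 1.0 / n_source)
b = np.full(n_target, 1.0 / n_target)

# 2-D source and target points drawn with a fixed seed
rng = np.random.RandomState(0)
X_source = rng.randn(n_source, 2)
Y_target = rng.randn(n_target, 2)

# Squared Euclidean cost matrix, the default metric of ot.dist
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=-1)
print(M.shape)  # (7, 4)
```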
Call the stochastic average gradient (“SAG”) method to find the transportation matrix in the discrete case.
method = "SAG"  # stochastic average gradient on the semi-dual
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
                                                numItermax)
print(sag_pi)
Out:
[[2.55553509e-02 9.96395660e-02 1.76579142e-02 4.31178196e-06]
[1.21640234e-01 1.25357448e-02 1.30225078e-03 7.37891338e-03]
[3.56123975e-03 7.61451746e-02 6.31505947e-02 1.33831456e-07]
[2.61515202e-02 3.34246014e-02 8.28734709e-02 4.07550428e-04]
[9.85500870e-03 7.52288517e-04 1.08262628e-02 1.21423583e-01]
[2.16904253e-02 9.03825797e-04 1.87178503e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177216e-02 2.39440107e-03]]
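To show what the “SAG” method does, here is a NumPy-only sketch of stochastic average gradient ascent on the discrete semi-dual of [18]. The gradient in β of the term for source point i is b − χ_i, where χ_ij = π_ij / a_i. SAG keeps the latest gradient of every term in memory and moves β along their average.

The sketch assumes the same setup as the dual section (7 and 4 uniform points, seed 0, squared Euclidean cost, reg = 1). The names `semi_dual_grad`, `sag_semi_dual` and `plan_from_beta` are hypothetical. The iteration count and the step size reg / max(a) are also assumptions. This is not POT's implementation.

```python
import numpy as np


def semi_dual_grad(b, M, reg, beta, i):
    """Gradient in beta of the semi-dual term of source point i: b - chi_i."""
    z = (beta - M[i]) / reg
    w = b * np.exp(z - z.max())  # shifted for numerical stability
    return b - w / w.sum()


def sag_semi_dual(a, b, M, reg, num_iter=20000, lr=None, seed=0):
    """Hypothetical stochastic average gradient (ascent) on the discrete semi-dual."""
    n_source, n_target = M.shape
    lr = reg / a.max() if lr is None else lr  # assumed step-size rule
    rng = np.random.default_rng(seed)
    beta = np.zeros(n_target)
    stored = np.zeros((n_source, n_target))  # latest gradient of each term
    grad_sum = np.zeros(n_target)  # sum of the stored gradients
    for i in rng.integers(n_source, size=num_iter):
        g = a[i] * semi_dual_grad(b, M, reg, beta, i)
        grad_sum += g - stored[i]  # swap the old gradient of term i for the new one
        stored[i] = g
        beta += lr / n_source * grad_sum  # ascent along the averaged gradient
    return beta


def plan_from_beta(a, b, M, reg, beta):
    """Entropic c-transform for alpha, then pi_ij = a_i b_j exp((alpha_i + beta_j - M_ij) / reg)."""
    z = (beta[None, :] - M) / reg
    z_max = z.max(axis=1, keepdims=True)
    alpha = -reg * (z_max[:, 0] + np.log((b[None, :] * np.exp(z - z_max)).sum(axis=1)))
    return a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)


# Setup assumed to match this example: uniform weights, seed 0, squared Euclidean cost
rng = np.random.RandomState(0)
X_source, Y_target = rng.randn(7, 2), rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=-1)
a, b, reg = np.full(7, 1 / 7), np.full(4, 1 / 4), 1.0

sag_pi = plan_from_beta(a, b, M, reg, sag_semi_dual(a, b, M, reg))
print(np.round(sag_pi, 4))
```

Because α comes from the c-transform, the rows of `sag_pi` match `a` by construction. Convergence shows up in how closely the column sums match `b`.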
Semi-continuous case
Sample a general measure a and a discrete measure b for the semi-continuous case, together with the points on which the source and target measures are defined, and compute the cost matrix M.
Call the averaged stochastic gradient descent (“ASGD”) method to find the transportation matrix in the semi-continuous case.
Out:
[3.90384008 7.64209593 3.91518482 2.64758955 1.46167873 3.29518073
2.76727099] [-2.49635294 -2.44523616 -0.88563215 5.82722125]
[[2.54246365e-02 9.89745062e-02 1.84534585e-02 4.54156507e-06]
[1.21233560e-01 1.24742477e-02 1.36334480e-03 7.78599048e-03]
[3.48643665e-03 7.44288938e-02 6.49416737e-02 1.38712186e-07]
[2.54131777e-02 3.24299950e-02 8.45946755e-02 4.19294565e-04]
[9.35265652e-03 7.12821872e-04 1.07924987e-02 1.21999166e-01]
[2.06712576e-02 8.60007373e-04 1.87378719e-03 1.19452091e-01]
[4.04847539e-02 2.58785901e-02 7.40235955e-02 2.47020341e-03]]
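In the semi-continuous setting, the source measure is only accessed through samples. ASGD therefore replaces the stored-gradient average of SAG with plain stochastic gradient steps and returns the running average of the iterates.

In the NumPy-only sketch below, each iteration draws one source index i with probability a_i and takes a step of size lr / √k along b − χ_i. It then updates the running mean of β. The names, seed, iteration count and step-size rule are assumptions for illustration, not POT's code. The setup is assumed to match the dual section of this example.

```python
import numpy as np


def semi_dual_grad(b, M, reg, beta, i):
    """Gradient in beta of the semi-dual term of source point i: b - chi_i."""
    z = (beta - M[i]) / reg
    w = b * np.exp(z - z.max())  # shifted for numerical stability
    return b - w / w.sum()


def asgd_semi_dual(a, b, M, reg, num_iter=20000, lr=None, seed=0):
    """Hypothetical averaged SGD sketch: source indices are sampled from a."""
    n_source, n_target = M.shape
    lr = reg / a.max() if lr is None else lr  # assumed step-size rule
    rng = np.random.default_rng(seed)
    samples = rng.choice(n_source, size=num_iter, p=a)  # draws from the source measure
    beta = np.zeros(n_target)
    beta_avg = np.zeros(n_target)
    for k, i in enumerate(samples, start=1):
        beta += lr / np.sqrt(k) * semi_dual_grad(b, M, reg, beta, i)
        beta_avg += (beta - beta_avg) / k  # running mean of the iterates
    return beta_avg


def plan_from_beta(a, b, M, reg, beta):
    """Entropic c-transform for alpha, then pi_ij = a_i b_j exp((alpha_i + beta_j - M_ij) / reg)."""
    z = (beta[None, :] - M) / reg
    z_max = z.max(axis=1, keepdims=True)
    alpha = -reg * (z_max[:, 0] + np.log((b[None, :] * np.exp(z - z_max)).sum(axis=1)))
    return a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)


# Setup assumed to match this example: uniform weights, seed 0, squared Euclidean cost
rng = np.random.RandomState(0)
X_source, Y_target = rng.randn(7, 2), rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=-1)
a, b, reg = np.full(7, 1 / 7), np.full(4, 1 / 4), 1.0

asgd_pi = plan_from_beta(a, b, M, reg, asgd_semi_dual(a, b, M, reg))
print(np.round(asgd_pi, 4))
```

Averaged SGD converges more slowly than SAG. This is consistent with the ASGD plan above agreeing with the Sinkhorn plan only to two or three digits.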
Compare the results with the Sinkhorn algorithm
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
Out:
[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
[1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
[3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
[2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
[9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
[2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]
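For comparison, Sinkhorn's algorithm solves the same entropic problem by alternately rescaling the rows and columns of the Gibbs kernel K = exp(−M / reg). Below is a minimal NumPy sketch of those matrix-scaling iterations. It uses a fixed number of iterations instead of the convergence test that `ot.sinkhorn` applies, and the name `sinkhorn_sketch` is hypothetical. Under the assumed setup, its result should agree with the Sinkhorn plans printed in this example.

```python
import numpy as np


def sinkhorn_sketch(a, b, M, reg, num_iter=1000):
    """Plain Sinkhorn matrix scaling; returns diag(u) K diag(v)."""
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iter):
        u = a / (K @ v)  # match the row marginals
        v = b / (K.T @ u)  # match the column marginals
    return u[:, None] * K * v[None, :]


# Setup assumed to match this example: uniform weights, seed 0, squared Euclidean cost
rng = np.random.RandomState(0)
X_source, Y_target = rng.randn(7, 2), rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=-1)
a, b, reg = np.full(7, 1 / 7), np.full(4, 1 / 4), 1.0

sinkhorn_pi = sinkhorn_sketch(a, b, M, reg)
print(np.round(sinkhorn_pi, 4))
```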
Plot Transportation Matrices
For SAG
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
pl.show()

For ASGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
pl.show()

For Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()

Compute the Transportation Matrix for the Dual Problem
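In the dual problem, both potentials are optimized directly. Following the entropic formulation of [19] in the notation of this example (ε = `reg`), the objective and the plan recovered from the potentials read as below. Treat this as a summary rather than POT's exact internal form.

```latex
\max_{\alpha \in \mathbb{R}^{n},\; \beta \in \mathbb{R}^{m}} \;
\sum_{i=1}^{n} a_i \alpha_i + \sum_{j=1}^{m} b_j \beta_j
- \varepsilon \sum_{i,j} a_i b_j \exp\!\Big(\frac{\alpha_i + \beta_j - M_{ij}}{\varepsilon}\Big),
\qquad
\pi_{ij} = a_i b_j \exp\!\Big(\frac{\alpha_i + \beta_j - M_{ij}}{\varepsilon}\Big).
```

The objective is a sum over pairs (i, j). Gradient estimates can therefore be computed from minibatches of source and target points, and `batch_size` sets the size of those minibatches.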
Semi-continuous case
Sample a general measure a and a discrete measure b for the semi-continuous case, and compute the cost matrix M.
n_source = 7
n_target = 4
reg = 1
numItermax = 100000
lr = 0.1
batch_size = 3
log = True
a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)
rng = np.random.RandomState(0)
X_source = rng.randn(n_source, 2)
Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
Call the stochastic gradient descent (“SGD”) method on the dual problem to find the transportation matrix in the semi-continuous case.
sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
                                                         batch_size, numItermax,
                                                         lr, log=log)
print(log_sgd['alpha'], log_sgd['beta'])
print(sgd_dual_pi)
Out:
[0.91979484 2.79773247 1.06997942 0.02111325 0.6066282 1.81574371
0.11264079] [0.34327594 0.47420468 1.57241179 4.95374025]
[[2.20057171e-02 9.27824871e-02 1.09053175e-02 9.59186767e-08]
[1.63296807e-02 1.81983115e-03 1.25383476e-04 2.55909005e-05]
[3.46704857e-03 8.01644112e-02 4.40941924e-02 3.36596525e-09]
[3.14506745e-02 4.34690118e-02 7.14815214e-02 1.26621394e-05]
[6.80497308e-02 5.61738692e-03 5.36158505e-02 2.16603342e-02]
[8.05546690e-02 3.62984339e-03 4.98567370e-03 1.13588520e-02]
[4.87119380e-02 3.37245692e-02 6.08126387e-02 7.25259602e-05]]
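The entropic dual, ⟨a, α⟩ + ⟨b, β⟩ − reg · Σ_ij a_i b_j exp((α_i + β_j − M_ij) / reg), is a sum over pairs (i, j). Each stochastic step therefore only needs the rows and columns in a sampled minibatch.

The NumPy-only sketch below illustrates minibatch stochastic gradient ascent on this objective in the spirit of [19]. It uses a decreasing step lr / √k. The names `dual_objective` and `sgd_dual` are hypothetical, as are the seed, the iteration count and `lr = 1.0`, which is larger than the example's `lr = 0.1`. It is not POT's implementation.

```python
import numpy as np


def dual_objective(a, b, M, reg, alpha, beta):
    """<a, alpha> + <b, beta> - reg * sum_ij a_i b_j exp((alpha_i + beta_j - M_ij) / reg)."""
    E = np.exp((alpha[:, None] + beta[None, :] - M) / reg)
    return a @ alpha + b @ beta - reg * np.sum(a[:, None] * b[None, :] * E)


def sgd_dual(a, b, M, reg, batch_size=3, num_iter=20000, lr=1.0, seed=0):
    """Hypothetical minibatch stochastic gradient ascent on the entropic dual."""
    n_source, n_target = M.shape
    rng = np.random.default_rng(seed)
    alpha = np.zeros(n_source)
    beta = np.zeros(n_target)
    for k in range(1, num_iter + 1):
        I = rng.choice(n_source, batch_size, replace=False)  # source minibatch
        J = rng.choice(n_target, batch_size, replace=False)  # target minibatch
        # Plan entries pi_ij restricted to the minibatch
        P = a[I, None] * b[None, J] * np.exp(
            (alpha[I, None] + beta[None, J] - M[np.ix_(I, J)]) / reg)
        # In expectation proportional to a_i - sum_j pi_ij and b_j - sum_i pi_ij
        grad_alpha = a[I] * batch_size / n_target - P.sum(axis=1)
        grad_beta = b[J] * batch_size / n_source - P.sum(axis=0)
        step = lr / np.sqrt(k)  # decreasing step size
        alpha[I] += step * grad_alpha
        beta[J] += step * grad_beta
    return alpha, beta


# Setup assumed to match the dual section of this example
rng = np.random.RandomState(0)
X_source, Y_target = rng.randn(7, 2), rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=-1)
a, b, reg = np.full(7, 1 / 7), np.full(4, 1 / 4), 1.0

alpha, beta = sgd_dual(a, b, M, reg)
sgd_dual_pi = a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)
print(dual_objective(a, b, M, reg, alpha, beta))
print(np.round(sgd_dual_pi, 4))
```

Unlike the semi-dual sketches, neither marginal holds by construction here. Both row and column sums only approach a and b as the dual potentials converge.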
Compare the results with the Sinkhorn algorithm
Call the Sinkhorn algorithm from POT.
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
Out:
[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
[1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
[3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
[2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
[9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
[2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]
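One way to quantify this comparison is to check how well each plan satisfies the marginal constraints. An exact solution has row sums a = 1/7 and column sums b = 1/4. The sketch below applies a hypothetical helper `marginal_violation` to the SGD and Sinkhorn plans printed above, with the values copied from those outputs.

The SGD plan violates the row constraint by more than 0.1 (its second row sums to about 0.018 instead of 1/7), so the dual SGD had not converged with these settings.

```python
import numpy as np

# Plans copied from the printed outputs of the dual section
sgd_dual_pi = np.array([
    [2.20057171e-02, 9.27824871e-02, 1.09053175e-02, 9.59186767e-08],
    [1.63296807e-02, 1.81983115e-03, 1.25383476e-04, 2.55909005e-05],
    [3.46704857e-03, 8.01644112e-02, 4.40941924e-02, 3.36596525e-09],
    [3.14506745e-02, 4.34690118e-02, 7.14815214e-02, 1.26621394e-05],
    [6.80497308e-02, 5.61738692e-03, 5.36158505e-02, 2.16603342e-02],
    [8.05546690e-02, 3.62984339e-03, 4.98567370e-03, 1.13588520e-02],
    [4.87119380e-02, 3.37245692e-02, 6.08126387e-02, 7.25259602e-05],
])
sinkhorn_pi = np.array([
    [2.55553508e-02, 9.96395661e-02, 1.76579142e-02, 4.31178193e-06],
    [1.21640234e-01, 1.25357448e-02, 1.30225079e-03, 7.37891333e-03],
    [3.56123974e-03, 7.61451746e-02, 6.31505947e-02, 1.33831455e-07],
    [2.61515201e-02, 3.34246014e-02, 8.28734709e-02, 4.07550425e-04],
    [9.85500876e-03, 7.52288523e-04, 1.08262629e-02, 1.21423583e-01],
    [2.16904255e-02, 9.03825804e-04, 1.87178504e-03, 1.18391107e-01],
    [4.15462212e-02, 2.65987989e-02, 7.23177217e-02, 2.39440105e-03],
])

a = np.full(7, 1 / 7)  # uniform source weights, as in ot.utils.unif(7)
b = np.full(4, 1 / 4)  # uniform target weights, as in ot.utils.unif(4)


def marginal_violation(pi, a, b):
    """Largest absolute deviation of the row and column sums of pi from a and b."""
    return max(np.abs(pi.sum(axis=1) - a).max(), np.abs(pi.sum(axis=0) - b).max())


print(marginal_violation(sgd_dual_pi, a, b))
print(marginal_violation(sinkhorn_pi, a, b))
```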
Plot Transportation Matrices
For SGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
pl.show()

For Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()
