Stochastic examples

This example shows how to use the stochastic optimization algorithms from the POT library for discrete and semi-continuous measures, following [18] and [19].

[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018).

# Author: Kilian Fatras <kilian.fatras@gmail.com>
#
# License: MIT License

import matplotlib.pylab as pl
import numpy as np
import ot
import ot.plot

Compute the Transportation Matrix for the Semi-Dual Problem
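
For reference, here is the entropic problem solved in this example, written in the convention of [18] with ε playing the role of POT's `reg`. The semi-dual matches the primal value up to an additive constant. The plan is recovered from the optimal β through the smoothed c-transform α(β).

```latex
\min_{\pi \in \Pi(a, b)} \; \langle \pi, M \rangle + \varepsilon\, \mathrm{KL}(\pi \,\|\, a \otimes b)

\max_{\beta \in \mathbb{R}^{m}} \; \sum_{i=1}^{n} a_i \Big( \alpha_i(\beta) + \sum_{j=1}^{m} \beta_j b_j \Big),
\qquad
\alpha_i(\beta) = -\varepsilon \log \sum_{j=1}^{m} b_j \exp\Big( \frac{\beta_j - M_{ij}}{\varepsilon} \Big)

\pi_{ij} = a_i b_j \exp\Big( \frac{\alpha_i(\beta) + \beta_j - M_{ij}}{\varepsilon} \Big)
```

Each term of the sum over i has gradient b − χ_i in β, where χ_ij = b_j exp((β_j − M_ij)/ε) / Σ_k b_k exp((β_k − M_ik)/ε). This is why one source index at a time gives a cheap stochastic gradient.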

Discrete case

Sample two discrete measures for the discrete case and compute their cost matrix c.
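
The sampling code is not reproduced on this page. Below is a minimal NumPy sketch of this step under a few assumptions: uniform weights, support points drawn from a standard 2-D Gaussian, and a squared Euclidean cost. The sizes 7 and 4 match the matrices printed below, but the seed and the point distribution are assumptions.

```python
import numpy as np

n_source, n_target = 7, 4  # sizes matching the 7 x 4 matrices printed on this page

# Uniform weights for the two discrete measures (assumed).
a = np.full(n_source, 1.0 / n_source)
b = np.full(n_target, 1.0 / n_target)

# Support points drawn from a standard 2-D Gaussian (assumed).
rng = np.random.default_rng(0)
X_source = rng.standard_normal((n_source, 2))
Y_target = rng.standard_normal((n_target, 2))

# Cost matrix c with c[i, j] = ||x_i - y_j||^2, shape (n_source, n_target).
c = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=-1)
print(c.shape)
```

With POT installed, `ot.dist(X_source, Y_target)` computes the same squared Euclidean cost, since `'sqeuclidean'` is its default metric.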

Call the “SAG” (stochastic average gradient) method to find the transportation matrix in the discrete case.
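
To make the SAG step concrete, here is a from-scratch NumPy sketch of stochastic average gradient ascent on the entropic semi-dual. It is not POT's implementation. The data, `reg = 1`, the step size `lr = 1` and the iteration count are all assumptions. Each step draws one source index i, refreshes the stored gradient of that term, and moves β along the average of all stored gradients. The plan is then rebuilt through the smoothed c-transform, so its row sums equal a exactly.

```python
import numpy as np

def c_transform(beta, M, b, reg):
    """Smoothed c-transform: alpha_i = -reg * log(sum_j b_j * exp((beta_j - M_ij) / reg))."""
    Z = np.log(b)[None, :] + (beta[None, :] - M) / reg
    z_max = Z.max(axis=1, keepdims=True)  # shift for a stable log-sum-exp
    return -reg * (z_max[:, 0] + np.log(np.exp(Z - z_max).sum(axis=1)))

def semi_dual_grad(beta, cost_row, b, reg):
    """Gradient in beta of the semi-dual term of one source point: b - chi."""
    z = (beta - cost_row) / reg
    chi = b * np.exp(z - z.max())
    return b - chi / chi.sum()

def sag_semi_dual(a, b, M, reg, n_iter, lr, rng):
    """Stochastic average gradient ascent on the entropic semi-dual (sketch)."""
    n_source, n_target = M.shape
    beta = np.zeros(n_target)
    memory = np.zeros((n_source, n_target))  # last gradient stored for each source point
    grad_sum = np.zeros(n_target)            # sum of the stored gradients
    for _ in range(n_iter):
        i = rng.integers(n_source)
        g = n_source * a[i] * semi_dual_grad(beta, M[i], b, reg)
        grad_sum += g - memory[i]
        memory[i] = g
        beta += lr * grad_sum / n_source     # ascent step along the averaged gradient
    alpha = c_transform(beta, M, b, reg)
    return a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)

# Toy data (assumed): uniform weights, Gaussian points, squared Euclidean cost.
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((7, 2)), rng.standard_normal((4, 2))
a, b = np.full(7, 1 / 7), np.full(4, 1 / 4)
M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)

sag_pi = sag_semi_dual(a, b, M, reg=1.0, n_iter=10000, lr=1.0, rng=rng)
print(sag_pi)
```

In POT itself, `ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, 'SAG', numItermax)` is the library entry point for this step. Its internals may differ from this sketch.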

Out:

[[2.55553509e-02 9.96395660e-02 1.76579142e-02 4.31178196e-06]
 [1.21640234e-01 1.25357448e-02 1.30225078e-03 7.37891338e-03]
 [3.56123975e-03 7.61451746e-02 6.31505947e-02 1.33831456e-07]
 [2.61515202e-02 3.34246014e-02 8.28734709e-02 4.07550428e-04]
 [9.85500870e-03 7.52288517e-04 1.08262628e-02 1.21423583e-01]
 [2.16904253e-02 9.03825797e-04 1.87178503e-03 1.18391107e-01]
 [4.15462212e-02 2.65987989e-02 7.23177216e-02 2.39440107e-03]]

Semi-continuous case

Sample one general measure a and one discrete measure b for the semi-continuous case, sample the points on which the source and target measures are defined, and compute the cost matrix.
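
In the semi-continuous setting only b needs a finite support. The measure a is accessed through samples, and each sample x only requires its cost vector (c(x, y_j))_j against the target points. Below is a sketch of that access pattern, assuming a standard 2-D Gaussian as the general measure a. The helpers `sample_source` and `cost_rows` are hypothetical, not POT functions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Discrete target measure b: four fixed points with uniform weights.
Y_target = rng.standard_normal((4, 2))
b = np.full(4, 1 / 4)

def sample_source(rng, size):
    """Hypothetical general source measure a: a 2-D standard Gaussian, accessed only by sampling."""
    return rng.standard_normal((size, 2))

def cost_rows(X, Y):
    """Squared Euclidean cost between each sampled x (rows of X) and every target point y_j."""
    return ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)

# Draw a small batch from a and build only the cost rows it needs.
X_batch = sample_source(rng, 5)
C_batch = cost_rows(X_batch, Y_target)  # shape (5, 4): one row per sample
print(C_batch.shape)
```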

Call the “ASGD” (averaged stochastic gradient descent) method to find the transportation matrix in the semi-continuous case.
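
The ASGD step can be illustrated with a from-scratch NumPy sketch of averaged stochastic gradient ascent on the same semi-dual. It assumes a step size `lr / sqrt(k)` and the hand-picked values below, and it is not POT's implementation. Unlike SAG, it keeps no gradient memory, so it only needs fresh source samples. To make the result checkable, the source here is sampled uniformly from seven fixed points, which lets the plan be evaluated on a finite grid at the end. The solver itself only ever sees one cost row per sample.

```python
import numpy as np

def semi_dual_grad(beta, cost_row, b, reg):
    """Gradient in beta of the semi-dual term of one source sample: b - chi."""
    z = (beta - cost_row) / reg
    chi = b * np.exp(z - z.max())
    return b - chi / chi.sum()

def asgd_semi_dual(sample_cost_row, b, reg, n_iter, lr, rng):
    """Averaged SGD ascent on the entropic semi-dual; returns the averaged iterate."""
    beta = np.zeros_like(b)
    beta_avg = np.zeros_like(b)
    for k in range(1, n_iter + 1):
        cost_row = sample_cost_row(rng)  # cost vector of one fresh source sample
        beta += lr / np.sqrt(k) * semi_dual_grad(beta, cost_row, b, reg)
        beta_avg += (beta - beta_avg) / k  # running mean of beta_1, ..., beta_k
    return beta_avg

# Toy data (assumed): a is uniform over 7 Gaussian points.
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((7, 2)), rng.standard_normal((4, 2))
a, b = np.full(7, 1 / 7), np.full(4, 1 / 4)
M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
reg = 1.0

def sample_cost_row(rng):
    """Draw one source point from a (uniform here) and return its row of the cost matrix."""
    return M[rng.integers(M.shape[0])]

beta = asgd_semi_dual(sample_cost_row, b, reg, n_iter=20000, lr=7.0, rng=rng)

# Smoothed c-transform on the 7 source points, then the plan pi_ij.
Z = np.log(b)[None, :] + (beta[None, :] - M) / reg
z_max = Z.max(axis=1, keepdims=True)
alpha = -reg * (z_max[:, 0] + np.log(np.exp(Z - z_max).sum(axis=1)))
asgd_pi = a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)
print(asgd_pi)
```

Averaging the iterates smooths out the noise of single-sample steps, at the price of slower convergence than SAG. In POT, passing `'ASGD'` instead of `'SAG'` as the method of `ot.stochastic.solve_semi_dual_entropic` selects the averaged variant.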

Out:

[3.90384008 7.64209593 3.91518482 2.64758955 1.46167873 3.29518073
 2.76727099] [-2.49635294 -2.44523616 -0.88563215  5.82722125]
[[2.54246365e-02 9.89745062e-02 1.84534585e-02 4.54156507e-06]
 [1.21233560e-01 1.24742477e-02 1.36334480e-03 7.78599048e-03]
 [3.48643665e-03 7.44288938e-02 6.49416737e-02 1.38712186e-07]
 [2.54131777e-02 3.24299950e-02 8.45946755e-02 4.19294565e-04]
 [9.35265652e-03 7.12821872e-04 1.07924987e-02 1.21999166e-01]
 [2.06712576e-02 8.60007373e-04 1.87378719e-03 1.19452091e-01]
 [4.04847539e-02 2.58785901e-02 7.40235955e-02 2.47020341e-03]]

Compare the results with the Sinkhorn algorithm
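
The SAG and Sinkhorn matrices agree for a structural reason. For the same ε and cost, both plans have the form diag(u) K diag(v) with K = exp(−M/ε) and the same marginals a and b, and such a scaling is unique. The NumPy sketch below checks this on assumed toy data:

- it runs plain Sinkhorn iterations (POT's `ot.sinkhorn(a, b, M, reg)` returns the corresponding plan);
- it reads a semi-dual potential β off the scaling vector v;
- it rebuilds the plan through the smoothed c-transform.

```python
import numpy as np

def sinkhorn(a, b, M, reg, n_iter=5000):
    """Plain Sinkhorn scaling; returns pi = diag(u) K diag(v) with K = exp(-M / reg), and v."""
    K = np.exp(-M / reg)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # match the row marginals
        v = b / (K.T @ u)  # match the column marginals
    return u[:, None] * K * v[None, :], v

# Toy data (assumed): uniform weights, Gaussian points, squared Euclidean cost.
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((7, 2)), rng.standard_normal((4, 2))
a, b = np.full(7, 1 / 7), np.full(4, 1 / 4)
M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
reg = 1.0

sinkhorn_pi, v = sinkhorn(a, b, M, reg)

# Read a semi-dual potential off v, then rebuild the plan through the smoothed c-transform.
beta = reg * np.log(v / b)
Z = np.log(b)[None, :] + (beta[None, :] - M) / reg
z_max = Z.max(axis=1, keepdims=True)
alpha = -reg * (z_max[:, 0] + np.log(np.exp(Z - z_max).sum(axis=1)))
semi_dual_pi = a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)

print(np.abs(sinkhorn_pi - semi_dual_pi).max())
```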

Out:

[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
 [1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
 [3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
 [2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
 [9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
 [2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
 [4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]

Plot Transportation Matrices

For SAG

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
pl.show()
[Figure: semi-dual : OT matrix SAG]

For ASGD

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
pl.show()
[Figure: semi-dual : OT matrix ASGD]

For Sinkhorn

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()
[Figure: OT matrix Sinkhorn]

Compute the Transportation Matrix for the Dual Problem
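
For reference, here is the regularized dual used by the SGD method, in the notation of [18] (ε = reg, up to an additive constant). Its partial derivatives are the marginal violations of the plan π built from (α, β). That is why gradients estimated on minibatches of rows and columns stay cheap.

```latex
\max_{\alpha \in \mathbb{R}^{n},\, \beta \in \mathbb{R}^{m}} \;
\sum_{i=1}^{n} \alpha_i a_i + \sum_{j=1}^{m} \beta_j b_j
- \varepsilon \sum_{i=1}^{n} \sum_{j=1}^{m} \pi_{ij},
\qquad
\pi_{ij} = a_i b_j \exp\Big( \frac{\alpha_i + \beta_j - M_{ij}}{\varepsilon} \Big)

\frac{\partial}{\partial \alpha_i} = a_i - \sum_{j=1}^{m} \pi_{ij},
\qquad
\frac{\partial}{\partial \beta_j} = b_j - \sum_{i=1}^{n} \pi_{ij}
```

Maximizing over α alone for a fixed β sets every row sum to a_i. This gives back the smoothed c-transform of the semi-dual problem.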

Semi-continuous case

Sample one general measure a and one discrete measure b for the semi-continuous case, and compute the cost matrix c.

Call the “SGD” method on the dual problem to find the transportation matrix in the semi-continuous case.
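
The dual SGD step can be sketched from scratch in NumPy as minibatch stochastic gradient ascent on the regularized dual, in the spirit of [19]. Each step samples `batch_size` rows and columns and evaluates the plan only on that block. It then moves α and β along unbiased estimates of the marginal violations a_i − Σ_j π_ij and b_j − Σ_i π_ij. The data, the step size `lr / sqrt(k)` and the iteration count are assumptions, and this is not POT's implementation.

```python
import numpy as np

def dual_plan(alpha, beta, a, b, M, reg):
    """pi_ij = a_i b_j exp((alpha_i + beta_j - M_ij) / reg)."""
    return a[:, None] * b[None, :] * np.exp((alpha[:, None] + beta[None, :] - M) / reg)

def dual_objective(alpha, beta, a, b, M, reg):
    """Entropic dual value (up to an additive constant)."""
    return alpha @ a + beta @ b - reg * dual_plan(alpha, beta, a, b, M, reg).sum()

def sgd_dual(a, b, M, reg, batch_size, n_iter, lr, rng):
    """Minibatch stochastic gradient ascent on the entropic dual (sketch)."""
    n_source, n_target = M.shape
    alpha, beta = np.zeros(n_source), np.zeros(n_target)
    for k in range(1, n_iter + 1):
        I = rng.choice(n_source, size=batch_size, replace=False)  # sampled source indices
        J = rng.choice(n_target, size=batch_size, replace=False)  # sampled target indices
        # Plan entries on the sampled block only.
        P = dual_plan(alpha[I], beta[J], a[I], b[J], M[np.ix_(I, J)], reg)
        # Unbiased estimates of a_i - sum_j pi_ij and b_j - sum_i pi_ij on the batch.
        g_alpha = a[I] - (n_target / batch_size) * P.sum(axis=1)
        g_beta = b[J] - (n_source / batch_size) * P.sum(axis=0)
        step = lr / np.sqrt(k)
        alpha[I] += step * g_alpha
        beta[J] += step * g_beta
    return alpha, beta

# Toy data (assumed): uniform weights, Gaussian points, squared Euclidean cost.
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((7, 2)), rng.standard_normal((4, 2))
a, b = np.full(7, 1 / 7), np.full(4, 1 / 4)
M = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
reg = 1.0

alpha, beta = sgd_dual(a, b, M, reg, batch_size=3, n_iter=10000, lr=1.0, rng=rng)
sgd_dual_pi = dual_plan(alpha, beta, a, b, M, reg)
print(sgd_dual_pi)
```

In POT, `ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax)` is the library entry point for this problem.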

Out:

[0.91979484 2.79773247 1.06997942 0.02111325 0.6066282  1.81574371
 0.11264079] [0.34327594 0.47420468 1.57241179 4.95374025]
[[2.20057171e-02 9.27824871e-02 1.09053175e-02 9.59186767e-08]
 [1.63296807e-02 1.81983115e-03 1.25383476e-04 2.55909005e-05]
 [3.46704857e-03 8.01644112e-02 4.40941924e-02 3.36596525e-09]
 [3.14506745e-02 4.34690118e-02 7.14815214e-02 1.26621394e-05]
 [6.80497308e-02 5.61738692e-03 5.36158505e-02 2.16603342e-02]
 [8.05546690e-02 3.62984339e-03 4.98567370e-03 1.13588520e-02]
 [4.87119380e-02 3.37245692e-02 6.08126387e-02 7.25259602e-05]]

Compare the results with the Sinkhorn algorithm

Call the Sinkhorn algorithm from POT

Out:

[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
 [1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
 [3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
 [2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
 [9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
 [2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
 [4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]

Plot Transportation Matrices

For SGD

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
pl.show()
[Figure: dual : OT matrix SGD]

For Sinkhorn

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()
[Figure: OT matrix Sinkhorn]
