# Stochastic examples

This example shows how to use the stochastic optimization algorithms from the POT library for discrete and semi-continuous measures.

[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. Stochastic Optimization for Large-scale Optimal Transport. Advances in Neural Information Processing Systems (2016).

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representations (2018).

```python
# Author: Kilian Fatras <kilian.fatras@gmail.com>
#

import matplotlib.pylab as pl
import numpy as np
import ot
import ot.plot
```

## Compute the Transportation Matrix for the Semi-Dual Problem

### Discrete case

Sample two discrete measures and compute their cost matrix M.

```python
n_source = 7
n_target = 4
reg = 1
numItermax = 1000

a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

rng = np.random.RandomState(0)
X_source = rng.randn(n_source, 2)
Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
```
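To make the setup concrete, here is a NumPy-only sketch of what the two helper calls are expected to return. It assumes that `ot.dist` uses its default squared Euclidean metric and that `ot.utils.unif(n)` returns uniform weights `1/n`. The `_manual` names are illustrative and not part of POT.

```python
import numpy as np

n_source, n_target = 7, 4
rng = np.random.RandomState(0)
X_source = rng.randn(n_source, 2)
Y_target = rng.randn(n_target, 2)

# Uniform weights (assumed to be what ot.utils.unif(n) returns)
a_manual = np.full(n_source, 1.0 / n_source)
b_manual = np.full(n_target, 1.0 / n_target)

# Pairwise squared Euclidean distances via broadcasting:
# M_manual[i, j] = ||X_source[i] - Y_target[j]||^2
diff = X_source[:, None, :] - Y_target[None, :, :]
M_manual = np.sum(diff ** 2, axis=2)

print(M_manual.shape)
```

Each row of `M_manual` corresponds to a source point and each column to a target point, which is the layout the solvers below expect for `M`.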

Call the “SAG” method to find the transportation matrix in the discrete case.

```python
method = "SAG"
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
                                                numItermax)
print(sag_pi)
```

Out:

```
[[2.55553509e-02 9.96395660e-02 1.76579142e-02 4.31178196e-06]
[1.21640234e-01 1.25357448e-02 1.30225078e-03 7.37891338e-03]
[3.56123975e-03 7.61451746e-02 6.31505947e-02 1.33831456e-07]
[2.61515202e-02 3.34246014e-02 8.28734709e-02 4.07550428e-04]
[9.85500870e-03 7.52288517e-04 1.08262628e-02 1.21423583e-01]
[2.16904253e-02 9.03825797e-04 1.87178503e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177216e-02 2.39440107e-03]]
```
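The semi-dual solvers optimize only the target potential. Under the standard entropic semi-dual formulation of [18], the source potential then follows from a soft c-transform, and the plan is built from both potentials. The sketch below illustrates this with NumPy only. The helper names `c_transform` and `plan_from_beta` are hypothetical, and the formulas are an assumption about the formulation, not a copy of POT's code.

```python
import numpy as np

def c_transform(b, M, reg, beta):
    """Soft c-transform: source potential alpha induced by a target potential beta."""
    r = M - beta[None, :]                    # r[i, j] = M[i, j] - beta[j]
    r_min = r.min(axis=1, keepdims=True)     # shift for numerical stability
    s = np.sum(np.exp(-(r - r_min) / reg) * b[None, :], axis=1)
    return r_min[:, 0] - reg * np.log(s)

def plan_from_beta(a, b, M, reg, beta):
    """Transport plan built from beta and its c-transform."""
    alpha = c_transform(b, M, reg, beta)
    return (np.exp((alpha[:, None] + beta[None, :] - M) / reg)
            * a[:, None] * b[None, :])

# Same data as in the example above, rebuilt with NumPy only
rng = np.random.RandomState(0)
X_source = rng.randn(7, 2)
Y_target = rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=2)
a = np.full(7, 1.0 / 7)
b = np.full(4, 1.0 / 4)

beta = rng.randn(4)                          # arbitrary, non-optimal potential
pi = plan_from_beta(a, b, M, 1.0, beta)
print(np.allclose(pi.sum(axis=1), a))
```

With this construction the row marginals equal `a` for any `beta`; only the column marginals depend on how well `beta` has been optimized, which is the part the stochastic solver works on.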

### Semi-continuous case

Sample a general measure a and a discrete measure b for the semi-continuous case, sample the points on which the source and target measures are supported, and compute the cost matrix M.

```python
n_source = 7
n_target = 4
reg = 1
numItermax = 1000
log = True

a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

rng = np.random.RandomState(0)
X_source = rng.randn(n_source, 2)
Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
```

Call the “ASGD” method to find the transportation matrix in the semi-continuous case.

```python
method = "ASGD"
asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
                                                           numItermax, log=log)
print(log_asgd['alpha'], log_asgd['beta'])
print(asgd_pi)
```

Out:

```
[3.83505858 7.69070071 3.81966364 2.57264415 1.57300062 3.41693415
2.71184982] [-2.55698637 -2.35275077 -0.77893768  5.68867482]
[[2.23383226e-02 1.01348613e-01 1.91665156e-02 3.69117098e-06]
[1.19784020e-01 1.43644516e-02 1.59239688e-03 7.11627477e-03]
[2.98239176e-03 7.42032862e-02 6.56713551e-02 1.09764106e-07]
[2.21910493e-02 3.30038351e-02 8.73235697e-02 3.38688860e-04]
[9.83894886e-03 8.73963705e-04 1.34216315e-02 1.18722599e-01]
[2.19740936e-02 1.06547903e-03 2.35469100e-03 1.17462879e-01]
[3.60487015e-02 2.68557568e-02 7.79180160e-02 2.03466847e-03]]
```
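For intuition about what an averaged stochastic gradient method does on the semi-dual, here is a rough NumPy sketch in the spirit of [18]. It is not POT's implementation: the function names, the base step size `1 / max(a / reg)`, the `1 / sqrt(k)` decay, and the iteration count are all assumptions made for illustration.

```python
import numpy as np

def semi_dual_grad(b, M, reg, beta, i):
    """Stochastic gradient of the entropic semi-dual in beta for source sample i."""
    r = M[i, :] - beta
    w = np.exp(-(r - r.min()) / reg) * b     # shifted for numerical stability
    return b - w / w.sum()

def asgd_beta(a, b, M, reg, n_iter, seed=0):
    """Averaged SGD (ascent) on beta; returns the running average of the iterates."""
    rng = np.random.RandomState(seed)
    lr = 1.0 / np.max(a / reg)               # assumed base step size
    beta = np.zeros(len(b))
    beta_avg = np.zeros(len(b))
    for k in range(1, n_iter + 1):
        i = rng.choice(len(a), p=a)          # draw a source point according to a
        beta += lr / np.sqrt(k) * semi_dual_grad(b, M, reg, beta, i)
        beta_avg += (beta - beta_avg) / k    # running mean of beta_1..beta_k
    return beta_avg

def plan_from_beta(a, b, M, reg, beta):
    """Plan from beta and its soft c-transform (row marginals equal a)."""
    r = M - beta[None, :]
    r_min = r.min(axis=1, keepdims=True)
    s = np.sum(np.exp(-(r - r_min) / reg) * b[None, :], axis=1)
    alpha = r_min[:, 0] - reg * np.log(s)
    return (np.exp((alpha[:, None] + beta[None, :] - M) / reg)
            * a[:, None] * b[None, :])

rng = np.random.RandomState(0)
X_source = rng.randn(7, 2)
Y_target = rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=2)
a = np.full(7, 1.0 / 7)
b = np.full(4, 1.0 / 4)

beta_avg = asgd_beta(a, b, M, reg=1.0, n_iter=5000)
pi_asgd = plan_from_beta(a, b, M, 1.0, beta_avg)
print(np.max(np.abs(pi_asgd.sum(axis=0) - b)))  # column marginal error
```

Each step only touches one row of `M`, which is what makes this family of methods usable when the source measure can only be sampled. The price is a noisy estimate: the column marginals are matched only approximately after a finite number of iterations.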

Compare the results with the Sinkhorn algorithm.

```python
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
```

Out:

```
[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
[1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
[3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
[2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
[9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
[2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]
```
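As a point of reference for this comparison, the textbook Sinkhorn iteration can be written in a few lines of NumPy. This is a plain sketch without the stopping criterion or stabilization that a library solver would typically add; `sinkhorn_sketch` is a hypothetical name and the iteration count is an arbitrary choice.

```python
import numpy as np

def sinkhorn_sketch(a, b, M, reg, n_iter=5000):
    """Alternate scaling of the Gibbs kernel to match both marginals."""
    K = np.exp(-M / reg)                     # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)                    # rescale to match column marginals
        u = a / (K @ v)                      # rescale to match row marginals
    return u[:, None] * K * v[None, :]

rng = np.random.RandomState(0)
X_source = rng.randn(7, 2)
Y_target = rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=2)
a = np.full(7, 1.0 / 7)
b = np.full(4, 1.0 / 4)

pi_sketch = sinkhorn_sketch(a, b, M, reg=1.0)
print(np.max(np.abs(pi_sketch.sum(axis=0) - b)))  # column marginal error
```

Unlike the stochastic solvers, every iteration here uses the full cost matrix, so it serves as a deterministic baseline on small problems such as this one.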

### Plot Transportation Matrices

For SAG

```python
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
pl.show()
```

For ASGD

```python
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
pl.show()
```

For Sinkhorn

```python
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()
```

## Compute the Transportation Matrix for the Dual Problem

### Semi-continuous case

Sample a general measure a and a discrete measure b for the semi-continuous case and compute the cost matrix M.

```python
n_source = 7
n_target = 4
reg = 1
numItermax = 100000
lr = 0.1
batch_size = 3
log = True

a = ot.utils.unif(n_source)
b = ot.utils.unif(n_target)

rng = np.random.RandomState(0)
X_source = rng.randn(n_source, 2)
Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
```

Call the “SGD” dual method to find the transportation matrix in the semi-continuous case.

```python
sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
                                                         batch_size, numItermax,
                                                         lr, log=log)
print(log_sgd['alpha'], log_sgd['beta'])
print(sgd_dual_pi)
```

Out:

```
[0.91451453 2.79429644 1.06820713 0.01751949 0.60503425 1.80794178
0.11101296] [0.34690791 0.47652001 1.57233368 4.92276498]
[[2.19694740e-02 9.25077959e-02 1.08470384e-02 9.25033762e-08]
[1.63328805e-02 1.81779279e-03 1.24943634e-04 2.47252638e-05]
[3.47350215e-03 8.02079550e-02 4.40126757e-02 3.25752342e-09]
[3.14518763e-02 4.34134752e-02 7.12195317e-02 1.22319009e-05]
[6.81885585e-02 5.62144061e-03 5.35262761e-02 2.09662386e-02]
[8.02194585e-02 3.60998242e-03 4.94654076e-03 1.09268183e-02]
[4.88096613e-02 3.37477627e-02 6.07089842e-02 7.01995203e-05]]
```
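A quick way to read the printed potentials is to rebuild the plan from them and check its marginals. The sketch below does this with NumPy only, assuming the plan has the usual entropic dual form `exp((alpha_i + beta_j - M_ij) / reg) * a_i * b_j` and that the cost is the squared Euclidean distance. The potentials are copied from the output above.

```python
import numpy as np

rng = np.random.RandomState(0)
X_source = rng.randn(7, 2)
Y_target = rng.randn(4, 2)
M = ((X_source[:, None, :] - Y_target[None, :, :]) ** 2).sum(axis=2)
a = np.full(7, 1.0 / 7)
b = np.full(4, 1.0 / 4)
reg = 1.0

# Dual potentials copied from the printed output above
alpha = np.array([0.91451453, 2.79429644, 1.06820713, 0.01751949,
                  0.60503425, 1.80794178, 0.11101296])
beta = np.array([0.34690791, 0.47652001, 1.57233368, 4.92276498])

# Plan in the assumed entropic dual form
pi = np.exp((alpha[:, None] + beta[None, :] - M) / reg) * a[:, None] * b[None, :]

# Largest violation of the row and column marginal constraints
row_err = np.max(np.abs(pi.sum(axis=1) - a))
col_err = np.max(np.abs(pi.sum(axis=0) - b))
print(pi[0, 0], row_err, col_err)
```

For these values the marginal constraints are visibly violated: the second row of the printed matrix, for instance, sums to about 0.018 instead of 1/7. This suggests that, with the settings used here, the dual SGD iterates are still some distance from the Sinkhorn solution, so the two matrices should be compared with that in mind.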

### Compare the results with the Sinkhorn algorithm

Call the Sinkhorn algorithm from POT.

```python
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
```

Out:

```
[[2.55553508e-02 9.96395661e-02 1.76579142e-02 4.31178193e-06]
[1.21640234e-01 1.25357448e-02 1.30225079e-03 7.37891333e-03]
[3.56123974e-03 7.61451746e-02 6.31505947e-02 1.33831455e-07]
[2.61515201e-02 3.34246014e-02 8.28734709e-02 4.07550425e-04]
[9.85500876e-03 7.52288523e-04 1.08262629e-02 1.21423583e-01]
[2.16904255e-02 9.03825804e-04 1.87178504e-03 1.18391107e-01]
[4.15462212e-02 2.65987989e-02 7.23177217e-02 2.39440105e-03]]
```

### Plot Transportation Matrices

For SGD

```python
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
pl.show()
```

For Sinkhorn

```python
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
pl.show()
```

