Continuous OT plan estimation with PyTorch

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import numpy as np
import matplotlib.pyplot as pl
import torch
from torch import nn
import ot
import ot.plot

Data generation

torch.manual_seed(42)
np.random.seed(42)

n_source_samples = 10000
n_target_samples = 10000
# NOTE(review): ``theta`` and ``noise_level`` are not used anywhere in this
# example; they look like leftovers from another data generator.
theta = 2 * np.pi / 20
noise_level = 0.1

# Source: isotropic 2D Gaussian centred at the origin with std 0.5.
Xs = np.random.randn(n_source_samples, 2) * 0.5
# Target: isotropic 2D Gaussian with std 2 (shifted just below).
Xt = np.random.randn(n_target_samples, 2) * 2

# Shift the target so that it is centred at (4, 4), away from the source.
# The source -> target map is therefore the affine map x -> 4 * x + 4.
Xt = Xt + 4

Plot data

# Only the first ``nvisu`` points of each cloud are drawn to keep the figure light.
nvisu = 300
pl.figure(num=1, figsize=(5, 5))
pl.clf()
pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', alpha=0.5, label='Source samples')
pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', alpha=0.5, label='Target samples')
pl.legend(loc=0)
# Current axis limits (xmin, xmax, ymin, ymax); the density grid is built from them.
ax_bounds = pl.axis()
pl.title('Source and target distributions')
Source and target distributions

Out:

Text(0.5, 1.0, 'Source and target distributions')

Convert data to torch tensors

Estimating deep dual variables for entropic OT

# Convert the NumPy samples to torch tensors. ``Xs``/``Xt`` are float64
# NumPy arrays, so the tensors are float64, matching the ``.double()``
# potentials. The training loop and density plots below index ``xs``/``xt``.
xs = torch.tensor(Xs)
xt = torch.tensor(Xt)

torch.manual_seed(42)

# define the MLP model


class Potential(torch.nn.Module):
    """One-hidden-layer MLP mapping points to a scalar dual potential value.

    Parameters
    ----------
    dim : int, default=2
        Dimension of the input points.
    n_hidden : int, default=200
        Width of the hidden layer.
    """

    def __init__(self, dim=2, n_hidden=200):
        super().__init__()
        self.fc1 = nn.Linear(dim, n_hidden)
        self.fc2 = nn.Linear(n_hidden, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Evaluate the potential on each row of ``x``.

        ``x`` is expected to have shape (n, dim); the (n, 1) output of the
        last layer is flattened so that one scalar per point is returned.
        """
        output = self.fc1(x)
        output = self.relu(output)
        output = self.fc2(output)
        return output.ravel()


# Dual potentials: ``u`` is evaluated on source points, ``v`` on target points.
u = Potential().double()
v = Potential().double()

# entropic regularization strength
reg = 1

optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)

# number of iterations and minibatch size
n_iter = 1000
n_batch = 500


losses = []

for i in range(n_iter):

    # draw a random minibatch of source and target indices (with replacement)
    iperms = torch.randint(0, n_source_samples, (n_batch,))
    ipermt = torch.randint(0, n_target_samples, (n_batch,))

    xsi = xs[iperms]
    xti = xt[ipermt]

    # minus because we maximize the dual objective while the optimizer minimizes
    loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print(f"Iter: {i:3d}, loss={losses[-1]}")

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


# Training curve: the negated minibatch dual objective at every iteration.
pl.figure(num=2)
pl.plot(losses)
pl.grid()
pl.title(label='Dual objective (negative)')
pl.xlabel(xlabel="Iterations")
Dual objective (negative)

Out:

Iter:   0, loss=0.2393484646308117
Iter:  10, loss=-12.642282658482282
Iter:  20, loss=-16.17414737436737
Iter:  30, loss=-19.163954594319062
Iter:  40, loss=-21.568419626220432
Iter:  50, loss=-25.301421042321348
Iter:  60, loss=-27.83083778430649
Iter:  70, loss=-31.87651959359454
Iter:  80, loss=-32.21879930758375
Iter:  90, loss=-33.4451848716634
Iter: 100, loss=-34.85373096612886
Iter: 110, loss=-34.34720398072969
Iter: 120, loss=-35.710736296778066
Iter: 130, loss=-35.41460794588291
Iter: 140, loss=-35.427335276898006
Iter: 150, loss=-36.95647523380266
Iter: 160, loss=-35.63352357746055
Iter: 170, loss=-34.147411707619916
Iter: 180, loss=-35.38045016076242
Iter: 190, loss=-34.20851722920263
Iter: 200, loss=-34.0483252573866
Iter: 210, loss=-32.920184012645834
Iter: 220, loss=-33.28786346616246
Iter: 230, loss=-34.595481968365874
Iter: 240, loss=-36.74007551169939
Iter: 250, loss=-34.87532094924726
Iter: 260, loss=-34.16038650643604
Iter: 270, loss=-35.51952191315499
Iter: 280, loss=-38.21124977745394
Iter: 290, loss=-34.776473454024725
Iter: 300, loss=-35.236113046850186
Iter: 310, loss=-37.17877897507476
Iter: 320, loss=-35.67523008941068
Iter: 330, loss=-35.47856020676345
Iter: 340, loss=-34.20515530483855
Iter: 350, loss=-34.31842178792488
Iter: 360, loss=-37.41746085650228
Iter: 370, loss=-36.49945342906269
Iter: 380, loss=-37.101409600220826
Iter: 390, loss=-35.95709453591492
Iter: 400, loss=-35.351964349897315
Iter: 410, loss=-36.097523666610016
Iter: 420, loss=-35.76370892349192
Iter: 430, loss=-35.03460532012972
Iter: 440, loss=-36.117010646149254
Iter: 450, loss=-36.63143924731369
Iter: 460, loss=-36.3493275492541
Iter: 470, loss=-36.85056949023975
Iter: 480, loss=-35.673259773400964
Iter: 490, loss=-35.50129693277133
Iter: 500, loss=-36.691373671209874
Iter: 510, loss=-38.2506577228827
Iter: 520, loss=-37.48779395563362
Iter: 530, loss=-36.79890479507191
Iter: 540, loss=-37.57060888796476
Iter: 550, loss=-36.5638137066962
Iter: 560, loss=-36.46562709782847
Iter: 570, loss=-36.40687844211682
Iter: 580, loss=-37.2092261947815
Iter: 590, loss=-38.24178255185049
Iter: 600, loss=-34.83767190886772
Iter: 610, loss=-37.091495197788845
Iter: 620, loss=-37.2269199634254
Iter: 630, loss=-36.79958582295997
Iter: 640, loss=-37.25712095782331
Iter: 650, loss=-35.693215050027426
Iter: 660, loss=-36.68384667580914
Iter: 670, loss=-37.19213769075985
Iter: 680, loss=-37.91790237056651
Iter: 690, loss=-36.30587890071249
Iter: 700, loss=-36.76524124674425
Iter: 710, loss=-36.096652875261434
Iter: 720, loss=-36.640525064634964
Iter: 730, loss=-36.63630669836857
Iter: 740, loss=-35.22361859806151
Iter: 750, loss=-36.922494326811055
Iter: 760, loss=-36.97476232035201
Iter: 770, loss=-37.234520256435054
Iter: 780, loss=-37.16928305689463
Iter: 790, loss=-38.25788596475691
Iter: 800, loss=-37.23397446378046
Iter: 810, loss=-36.1721306311481
Iter: 820, loss=-37.247867302837534
Iter: 830, loss=-36.96014658076261
Iter: 840, loss=-37.03767527010373
Iter: 850, loss=-38.030040026953046
Iter: 860, loss=-37.1448266201477
Iter: 870, loss=-35.932275800816825
Iter: 880, loss=-37.14373204690507
Iter: 890, loss=-36.26008919084346
Iter: 900, loss=-36.517689779489814
Iter: 910, loss=-39.39905188612141
Iter: 920, loss=-36.460772222267416
Iter: 930, loss=-37.018913815021016
Iter: 940, loss=-36.499518213813715
Iter: 950, loss=-36.96986872327708
Iter: 960, loss=-37.024141665861244
Iter: 970, loss=-36.88576653718013
Iter: 980, loss=-38.428442172282146
Iter: 990, loss=-35.10248001587743

Text(0.5, 23.52222222222222, 'Iterations')

Plot the density on target for a given source sample

nv = 100  # grid resolution along each axis
xl = np.linspace(ax_bounds[0], ax_bounds[1], nv)
yl = np.linspace(ax_bounds[2], ax_bounds[3], nv)

XX, YY = np.meshgrid(xl, yl)

xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)

# Weights of the grid points under the target distribution. The target
# samples were drawn as ``randn * 2 + 4``, i.e. an isotropic Gaussian with
# mean (4, 4) and std 2, so the exponent's denominator is 2 * std**2 = 8.
target_mean = 4
target_std = 2
wxg = np.exp(-((xg[:, 0] - target_mean)**2 + (xg[:, 1] - target_mean)**2) / (2 * target_std**2))
wxg = wxg / np.sum(wxg)

xg = torch.tensor(xg)
wxg = torch.tensor(wxg)


pl.figure(4, (12, 4))
pl.clf()

# Same visualization, one subplot per selected source sample.
for position, iv in enumerate((2, 3, 6), start=1):
    pl.subplot(1, 3, position)

    # Entropic plan between source sample ``iv`` and every grid point;
    # gradients are not needed for plotting.
    with torch.no_grad():
        Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg)
    Gg = Gg.reshape((nv, nv)).detach().numpy()

    pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05)
    pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05)
    pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0')
    # No ``label`` here: matplotlib's legend has no handler for the QuadMesh
    # returned by pcolormesh, so such a label would only trigger a warning.
    pl.pcolormesh(XX, YY, Gg, cmap='Greens')
    pl.legend(loc=0)
    pl.title('Density of transported source sample')
Density of transported source sample, Density of transported source sample, Density of transported source sample

Out:

Text(0.5, 1.0, 'Density of transported source sample')

Total running time of the script: ( 1 minutes 28.178 seconds)

Gallery generated by Sphinx-Gallery