Dual OT solvers for entropic and quadratic regularized OT with PyTorch

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 3

import numpy as np
import matplotlib.pyplot as pl
import torch
import ot
import ot.plot

Data generation

# Seed torch's random number generator.
torch.manual_seed(1)

# Sample counts and generator settings.
n_source_samples = 100
n_target_samples = 100
noise_level = 0.1
theta = 2 * np.pi / 20  # only passed to the generator for the target set

# Draw the labelled source and target sets from the 'gaussrot' dataset.
Xs, ys = ot.datasets.make_data_classif(
    "gaussrot", n_source_samples, nz=noise_level
)
Xt, yt = ot.datasets.make_data_classif(
    "gaussrot", n_target_samples, nz=noise_level, theta=theta
)

# Scale the target samples labelled 2 by a factor of 3, so that this mode
# changes its variance (no linear mapping), then shift every target sample
# by 4.
Xt[yt == 2] *= 3
Xt = Xt + 4

Plot data

# Figure 1: scatter plot of the raw source and target samples.
pl.figure(1, (10, 5))
pl.clf()
for _samples, _marker, _label in (
    (Xs, '+', 'Source samples'),
    (Xt, 'o', 'Target samples'),
):
    # First column on the x-axis, second column on the y-axis.
    pl.scatter(_samples[:, 0], _samples[:, 1], marker=_marker, label=_label)
pl.legend(loc=0)
pl.title('Source and target distributions')
Source and target distributions

Out:

Text(0.5, 1.0, 'Source and target distributions')

Convert data to torch tensors

Estimating dual variables for entropic OT

# Torch versions of the sample arrays; the dual loss and the plan computed
# below are evaluated on these tensors.
xs = torch.tensor(Xs)
xt = torch.tensor(Xt)

# Dual variables, randomly initialised and optimised by gradient steps:
# u has one entry per source sample, v one entry per target sample.
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

# Regularisation strength passed to the entropic dual loss and plan.
reg = 0.5

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200

# History of the minimised (negated) dual objective, one value per iteration.
losses = []

for i in range(n_iter):

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    # Gradient step on (u, v); gradients are cleared for the next iteration.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


# Figure 2: evolution of the negated dual objective.
pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")

# Plan computed from the optimised dual variables.
Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)
Dual objective (negative)

Out:

Iter:   0, loss=0.20204949002247305
Iter:  10, loss=-19.59848703458015
Iter:  20, loss=-31.13869587996386
Iter:  30, loss=-35.31050127275935
Iter:  40, loss=-38.52289314113808
Iter:  50, loss=-40.28744798684676
Iter:  60, loss=-41.2688200716847
Iter:  70, loss=-41.76634561497488
Iter:  80, loss=-41.86524666837034
Iter:  90, loss=-41.92481373332549
Iter: 100, loss=-41.94786763969571
Iter: 110, loss=-41.959601897693986
Iter: 120, loss=-41.96568869115486
Iter: 130, loss=-41.96962380032497
Iter: 140, loss=-41.972649295229154
Iter: 150, loss=-41.975340650881094
Iter: 160, loss=-41.97781537956428
Iter: 170, loss=-41.98019165366636
Iter: 180, loss=-41.982491335324184
Iter: 190, loss=-41.984716231976606

Plot the estimated entropic OT plan

# Figure 3: the estimated entropic plan drawn between the source and target
# samples, with the samples scattered on top (zorder=2).
pl.figure(3, (10, 5))
pl.clf()
# Ge is detached from the autograd graph before conversion to a NumPy array.
ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)
pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2)
pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2)
pl.legend(loc=0)
# Title names the plotted plan, matching the quadratic-plan figure's wording.
pl.title('OT plan with entropic regularization')
Source and target distributions

Out:

Text(0.5, 1.0, 'Source and target distributions')

Estimating dual variables for quadratic OT

# Fresh dual variables for the quadratic problem, randomly initialised:
# u has one entry per source sample, v one entry per target sample.
# NOTE(review): xs and xt must already hold the source/target samples as
# torch tensors when this section runs.
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

# Regularisation strength passed to the quadratic dual loss and plan.
reg = 0.01

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200

# History of the minimised (negated) dual objective, one value per iteration.
losses = []


for i in range(n_iter):

    # minus because we maximize the dual loss
    loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print("Iter: {:3d}, loss={}".format(i, losses[-1]))

    # Gradient step on (u, v); gradients are cleared for the next iteration.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


# Figure 4: evolution of the negated dual objective.
pl.figure(4)
pl.plot(losses)
pl.grid()
pl.title('Dual objective (negative)')
pl.xlabel("Iterations")

# Plan computed from the optimised dual variables.
Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)
Dual objective (negative)

Out:

Iter:   0, loss=-0.0018442196020623663
Iter:  10, loss=-19.61204302244388
Iter:  20, loss=-30.862529273798163
Iter:  30, loss=-35.059752333875274
Iter:  40, loss=-38.24608161649095
Iter:  50, loss=-39.994411779113314
Iter:  60, loss=-41.050924337529416
Iter:  70, loss=-41.57251015120981
Iter:  80, loss=-41.720173488718714
Iter:  90, loss=-41.807425081072935
Iter: 100, loss=-41.83759189059595
Iter: 110, loss=-41.85046471143942
Iter: 120, loss=-41.85734468841109
Iter: 130, loss=-41.860904108172214
Iter: 140, loss=-41.86351529045967
Iter: 150, loss=-41.8657026282564
Iter: 160, loss=-41.86779447649588
Iter: 170, loss=-41.869903524245494
Iter: 180, loss=-41.87205340431508
Iter: 190, loss=-41.874263012636334

Plot the estimated quadratic OT plan

# Figure 5: the estimated quadratic plan drawn between the source and target
# samples, with the samples scattered on top (zorder=2).
pl.figure(5, (10, 5))
pl.clf()
# Gq is detached from the autograd graph before conversion to a NumPy array.
_plan_q = Gq.detach().numpy()
ot.plot.plot2D_samples_mat(Xs, Xt, _plan_q, alpha=0.1)
for _samples, _marker, _label in (
    (Xs, '+', 'Source samples'),
    (Xt, 'o', 'Target samples'),
):
    pl.scatter(
        _samples[:, 0], _samples[:, 1],
        marker=_marker, label=_label, zorder=2,
    )
pl.legend(loc=0)
pl.title('OT plan with quadratic regularization')
OT plan with quadratic regularization

Out:

Text(0.5, 1.0, 'OT plan with quadratic regularization')

Total running time of the script: ( 0 minutes 40.544 seconds)

Gallery generated by Sphinx-Gallery