Note
Go to the end to download the full example code.
Dual OT solvers for entropic and quadratic regularized OT with PyTorch
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
import numpy as np
import matplotlib.pyplot as pl
import torch
import ot
import ot.plot
Data generation
# Seed torch's RNG; the dual variables further down are drawn with torch.randn.
# NOTE(review): this seeds torch only -- whether the datasets generated below
# are reproducible depends on how make_data_classif draws its samples; confirm.
torch.manual_seed(1)

# Problem size and geometry of the toy dataset.
n_source_samples = 100
n_target_samples = 100
theta = 2 * np.pi / 20
noise_level = 0.1

# Source samples first, then target samples (the target call also gets theta).
Xs, ys = ot.datasets.make_data_classif(
    "gaussrot", n_source_samples, nz=noise_level
)
Xt, yt = ot.datasets.make_data_classif(
    "gaussrot", n_target_samples, nz=noise_level, theta=theta
)

# One of the target modes has its variance changed, so no linear mapping
# relates source and target.
Xt[yt == 2] *= 3

# Shift the whole target cloud.
Xt = Xt + 4
Plot data
# Figure 1: raw source and target point clouds.
pl.figure(1, figsize=(10, 5))
pl.clf()

# Draw both clouds with the same call, source first so legend order is kept.
for samples, marker, label in (
    (Xs, "+", "Source samples"),
    (Xt, "o", "Target samples"),
):
    pl.scatter(samples[:, 0], samples[:, 1], marker=marker, label=label)

pl.legend(loc=0)
pl.title("Source and target distributions")

Text(0.5, 1.0, 'Source and target distributions')
Convert data to torch tensors
# Copy the NumPy sample arrays into torch tensors; these are the inputs passed
# to the ot.stochastic dual losses below.
xs, xt = (torch.tensor(samples) for samples in (Xs, Xt))
Estimating dual variables for entropic OT
# Dual potentials optimised by gradient ascent: u is paired with the source
# samples xs and v with the target samples xt, so v must be sized by the
# number of *target* samples (the two counts only happen to be equal here).
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

# Regularization strength passed to the entropic dual loss and plan.
reg = 0.5

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200

# History of the (negated) dual objective, one value per iteration.
losses = []

for i in range(n_iter):
    # minus because we maximize the dual loss while Adam minimizes
    loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print(f"Iter: {i:3d}, loss={losses[-1]}")

    loss.backward()
    optimizer.step()
    # Clear gradients so they do not accumulate into the next iteration.
    optimizer.zero_grad()

# Figure 2: convergence of the negated dual objective.
pl.figure(2)
pl.plot(losses)
pl.grid()
pl.title("Dual objective (negative)")
pl.xlabel("Iterations")

# Build the transport plan from the fitted dual variables.
Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)

Iter: 0, loss=0.2020494900224731
Iter: 10, loss=-19.60366362085319
Iter: 20, loss=-31.555999463827295
Iter: 30, loss=-35.85662090130287
Iter: 40, loss=-38.37829707933387
Iter: 50, loss=-39.47961718550262
Iter: 60, loss=-39.90703272470927
Iter: 70, loss=-40.10492546954223
Iter: 80, loss=-40.17435202660893
Iter: 90, loss=-40.20795600896186
Iter: 100, loss=-40.221724226027234
Iter: 110, loss=-40.2289302291857
Iter: 120, loss=-40.23253276238543
Iter: 130, loss=-40.234551907323414
Iter: 140, loss=-40.235711515004226
Iter: 150, loss=-40.23647703543624
Iter: 160, loss=-40.23700391316219
Iter: 170, loss=-40.23738390884928
Iter: 180, loss=-40.237668835398956
Iter: 190, loss=-40.237889161055016
Plot the estimated entropic OT plan
# Figure 3: entropic OT plan drawn between source and target samples.
pl.figure(3, figsize=(10, 5))
pl.clf()

# Detach the plan from the autograd graph before handing it to the NumPy-based
# plotting helper.
ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)

# zorder=2 is passed so the sample markers are layered relative to the plan
# lines drawn by the call above.
for samples, marker, label in (
    (Xs, "+", "Source samples"),
    (Xt, "o", "Target samples"),
):
    pl.scatter(
        samples[:, 0], samples[:, 1], marker=marker, label=label, zorder=2
    )

pl.legend(loc=0)
# NOTE(review): this title repeats the one used for figure 1; the quadratic
# counterpart uses "OT plan with quadratic regularization". Consider an
# analogous entropic title when the example output is regenerated.
pl.title("Source and target distributions")

Text(0.5, 1.0, 'Source and target distributions')
Estimating dual variables for quadratic OT
# Fresh dual potentials for the quadratic problem: u is paired with the source
# samples xs and v with the target samples xt, so v must be sized by the
# number of *target* samples (the two counts only happen to be equal here).
u = torch.randn(n_source_samples, requires_grad=True)
v = torch.randn(n_target_samples, requires_grad=True)

# Regularization strength passed to the quadratic dual loss and plan.
reg = 0.01

optimizer = torch.optim.Adam([u, v], lr=1)

# number of iterations
n_iter = 200

# History of the (negated) dual objective, one value per iteration.
losses = []

for i in range(n_iter):
    # minus because we maximize the dual loss while Adam minimizes
    loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg)
    losses.append(float(loss.detach()))

    if i % 10 == 0:
        print(f"Iter: {i:3d}, loss={losses[-1]}")

    loss.backward()
    optimizer.step()
    # Clear gradients so they do not accumulate into the next iteration.
    optimizer.zero_grad()

# Figure 4: convergence of the negated dual objective.
pl.figure(4)
pl.plot(losses)
pl.grid()
pl.title("Dual objective (negative)")
pl.xlabel("Iterations")

# Build the transport plan from the fitted dual variables.
Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)

Iter: 0, loss=-0.0018442196020623663
Iter: 10, loss=-19.639168295306263
Iter: 20, loss=-31.2546764762086
Iter: 30, loss=-35.431175810389306
Iter: 40, loss=-38.0347919368864
Iter: 50, loss=-39.196977071430894
Iter: 60, loss=-39.687629920963104
Iter: 70, loss=-39.921383085201164
Iter: 80, loss=-40.02912736914793
Iter: 90, loss=-40.080626182150986
Iter: 100, loss=-40.100546169775114
Iter: 110, loss=-40.111312566901844
Iter: 120, loss=-40.11712450121973
Iter: 130, loss=-40.12056797459418
Iter: 140, loss=-40.122300943897656
Iter: 150, loss=-40.12316242431165
Iter: 160, loss=-40.123602787538786
Iter: 170, loss=-40.12385478500847
Iter: 180, loss=-40.12401954754222
Iter: 190, loss=-40.12413320707049
Plot the estimated quadratic OT plan
# Figure 5: quadratic-regularized OT plan drawn between source and target.
pl.figure(5, figsize=(10, 5))
pl.clf()

# Detach the plan from the autograd graph before handing it to the NumPy-based
# plotting helper.
ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)

# zorder=2 is passed so the sample markers are layered relative to the plan
# lines drawn by the call above.
for samples, marker, label in (
    (Xs, "+", "Source samples"),
    (Xt, "o", "Target samples"),
):
    pl.scatter(
        samples[:, 0], samples[:, 1], marker=marker, label=label, zorder=2
    )

pl.legend(loc=0)
pl.title("OT plan with quadratic regularization")

Text(0.5, 1.0, 'OT plan with quadratic regularization')
Total running time of the script: (0 minutes 13.858 seconds)