Quickstart Guide

Quickstart guide to the POT toolbox.

For better readability, only the use of POT is provided and the plotting code with matplotlib is hidden (but is available in the source file of the example).

Note

We use here the unified API of POT, which is more flexible and solves a wider range of problems with just a few functions. The classical API is still available (the unified API is a convenient wrapper around it), and we provide pointers to the classical API when needed.

# Author: Remi Flamary
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 4

# Import necessary libraries

import numpy as np
import pylab as pl

import ot

2D data example

We first generate two sets of samples in 2D, with 25 and 50 samples respectively, located on circles of radius 2 and 4. The weights of the samples are uniform.

# Problem size
n1 = 25
n2 = 50

# Generate random data
np.random.seed(0)
a = ot.utils.unif(n1)  # weights of points in the source domain
b = ot.utils.unif(n2)  # weights of points in the target domain

x1 = np.random.randn(n1, 2)
x1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2

x2 = np.random.randn(n2, 2)
x2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4

# Compute the cost matrix
C = ot.dist(x1, x2)  # Squared Euclidean cost matrix by default
Source and target distributions, Cost matrix C
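
As a rough illustration of what the default cost matrix contains, the sketch below rebuilds a squared Euclidean cost with plain NumPy broadcasting. The helper sqeuclidean_cost and the arrays xs and xt are illustrative names that are not part of POT. ot.dist is expected to return the same values up to floating-point error, possibly with a different formula internally.

```python
import numpy as np


def sqeuclidean_cost(x, y):
    """Pairwise squared Euclidean distances between the rows of x and y."""
    diff = x[:, None, :] - y[None, :, :]  # shape (n, m, d)
    return np.sum(diff**2, axis=-1)  # shape (n, m)


rng = np.random.default_rng(0)
xs = rng.normal(size=(25, 2))  # illustrative source samples
xt = rng.normal(size=(50, 2))  # illustrative target samples

C_manual = sqeuclidean_cost(xs, xt)
print(C_manual.shape)  # (25, 50)
```

Entry C_manual[i, j] is the cost of moving one unit of mass from xs[i] to xt[j], which is the convention used in the rest of this guide.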

Solving exact Optimal Transport

Solve the Optimal Transport problem between the samples

The ot.solve_sample() function can be used to solve the Optimal Transport problem between two sets of samples. The function takes the positions of the source and target samples as its first two arguments and returns an ot.utils.OTResult object.

# Solve the OT problem
sol = ot.solve_sample(x1, x2, a, b)

# get the OT plan
P = sol.plan

# get the OT loss
loss = sol.value

# get the dual potentials
alpha, beta = sol.potentials

print(f"OT loss = {loss:1.3f}")
OT plan P loss=5.292, Dual potentials
OT loss = 5.292

The figure above shows the Optimal Transport plan between the source and target samples. The color intensity represents the amount of mass transported between the samples. The dual potentials of the OT problem are also shown.

The weights of the samples in the source and target domains a and b are given to the function. If not provided, the weights are assumed to be uniform. See ot.solve_sample() for more details.

The ot.utils.OTResult object contains the following attributes:

  • value: the value of the OT problem

  • plan: the OT matrix

  • potentials: Dual potentials of the OT problem

  • log: log dictionary of the solver

The OT matrix \(P\) is a matrix of size (n1, n2) where P[i,j] is the amount of mass transported from x1[i] to x2[j].

The OT loss is the sum of the element-wise product of the OT matrix and the cost matrix taken by default as the Squared Euclidean distance.
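
The two properties above (marginal constraints and loss as an element-wise sum) can be checked by hand on a tiny problem. The sketch below uses a hand-written feasible plan. All names ending in _toy are illustrative and independent of the variables used in this example.

```python
import numpy as np

a_toy = np.array([0.5, 0.5])  # source weights
b_toy = np.array([0.25, 0.25, 0.5])  # target weights

# Cost of moving mass from each source point (rows) to each target point (columns)
C_toy = np.array([[0.0, 1.0, 4.0],
                  [4.0, 1.0, 0.0]])

# Hand-written transport plan: P_toy[i, j] is the mass moved from i to j
P_toy = np.array([[0.25, 0.25, 0.0],
                  [0.0, 0.0, 0.5]])

# A valid plan has the source weights as row sums and the target weights as column sums
assert np.allclose(P_toy.sum(axis=1), a_toy)
assert np.allclose(P_toy.sum(axis=0), b_toy)

# The transport loss is the sum of the element-wise product of plan and cost
loss_toy = np.sum(P_toy * C_toy)
print(loss_toy)  # 0.25
```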

Optimal Transport problem with a custom cost matrix

The cost matrix can be customized by passing it to the more general ot.solve() function. The cost matrix should be a matrix of size (n1, n2) where C[i,j] is the cost of transporting mass from x1[i] to x2[j].

In this example, we use the cityblock (Manhattan) distance as the cost matrix.

# Compute the cost matrix
C_city = ot.dist(x1, x2, metric="cityblock")

# Solve the OT problem with the custom cost matrix
sol = ot.solve(C_city)
# the parameters a and b are not provided so uniform weights are assumed
P_city = sol.plan
# on empirical data the same can be done with ot.solve_sample :
# sol = ot.solve_sample(x1, x2, metric='cityblock')

# Get the OT loss (equivalent to ot.solve(C_city).value)
loss_city = sol.value  # same as np.sum(P_city * C_city)
OT plan (Cityblock) loss=2.925

Note that we show here how to solve the OT problem with a custom cost matrix with the more general ot.solve() function. But the same can be done with the ot.solve_sample() function by passing metric='cityblock' as argument.

The cost matrix can be computed with the ot.dist() function which computes the pairwise distance between two sets of samples or can be provided directly as a matrix by the user when no samples are available.
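
To illustrate the case where only a cost matrix is available, here is a minimal sketch that solves a tiny problem by enumeration. It relies on the fact that, with uniform weights and as many source as target points, exact OT admits an optimal plan that is a permutation matrix scaled by 1/n. The cost values and the names C_custom and P_custom are made up for illustration. Enumeration is only practical for a handful of points, whereas ot.solve() uses a dedicated solver.

```python
import itertools

import numpy as np

# Hypothetical cost of moving goods from 3 warehouses (rows) to 3 shops (columns)
C_custom = np.array([[4.0, 1.0, 3.0],
                     [2.0, 0.0, 5.0],
                     [3.0, 2.0, 2.0]])
n = C_custom.shape[0]

# Try every one-to-one assignment; each matched pair carries a mass of 1/n
best_perm, best_cost = None, np.inf
for perm in itertools.permutations(range(n)):
    cost = sum(C_custom[i, perm[i]] for i in range(n)) / n
    if cost < best_cost:
        best_perm, best_cost = perm, cost

# Build the corresponding transport plan
P_custom = np.zeros((n, n))
P_custom[np.arange(n), best_perm] = 1.0 / n

print(best_perm)  # (1, 0, 2)
print(best_cost)
```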

Note

The examples above use the unified API of POT. The classic API is still available, and the OT plan and loss can be computed with the ot.emd() and ot.emd2() functions as below:

P = ot.emd(a, b, C)
loss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b

Sinkhorn and Regularized OT

Entropic OT with Sinkhorn algorithm

# Solve the Sinkhorn problem (just add reg parameter value)
sol = ot.solve_sample(x1, x2, a, b, reg=1e-1)

# get the OT plan and loss
P_sink = sol.plan
loss_sink = sol.value  # objective value of the Sinkhorn problem (incl. entropy)
loss_sink_linear = sol.value_linear  # np.sum(P_sink * C) linear part of loss
Sinkhorn OT plan loss=5.566

The Sinkhorn algorithm solves the Entropic Regularized OT problem. The regularization strength can be controlled with the reg parameter. The Sinkhorn algorithm can be faster than the exact OT solver for large regularization strength but the solution is only an approximation of the exact OT problem and the OT plan is not sparse.
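
For intuition, the textbook Sinkhorn scaling iterations fit in a few lines of NumPy. This is a minimal sketch and not POT's implementation: it has no stopping criterion and no log-domain stabilization, and the function name sinkhorn_plan and the toy data are illustrative. The cost is normalized so that reg is on the same scale as the costs.

```python
import numpy as np


def sinkhorn_plan(a, b, C, reg, n_iter=2000):
    """Basic Sinkhorn scaling iterations for entropic OT."""
    K = np.exp(-C / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)  # rescale rows to match the source weights
        v = b / (K.T @ u)  # rescale columns to match the target weights
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 2))
xt = rng.normal(size=(7, 2))
C_toy = np.sum((xs[:, None, :] - xt[None, :, :]) ** 2, axis=-1)
C_toy /= C_toy.max()  # costs now lie in [0, 1]

a_toy = np.full(5, 1 / 5)
b_toy = np.full(7, 1 / 7)
P_toy = sinkhorn_plan(a_toy, b_toy, C_toy, reg=0.2)

print(np.abs(P_toy.sum(axis=1) - a_toy).max())  # marginal violation
print((P_toy > 0).all())  # True: every entry of the entropic plan is positive
```

Because every entry of the kernel is positive, so is every entry of the plan, which is why the entropic plan is dense.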

Quadratic Regularized OT

# Use quadratic regularization
sol_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type="L2")
P_quad = sol_quad.plan
loss_quad = sol_quad.value
OT plan loss=5.292, Sinkhorn plan loss=5.566, Quadratic reg plan loss=5.342

We plot above the OT plans obtained with different regularizations. The quadratic regularization is another common choice for regularized OT and preserves the sparsity of the OT plan.

Solve the Regularized OT problem with user-defined regularization

# Define a custom regularization function
def f(G):
    return 0.5 * np.sum(G**2)


def df(G):
    return G


P_reg = ot.solve_sample(x1, x2, a, b, reg=3, reg_type=(f, df)).plan
User-defined reg plan
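
When writing a custom regularization, a common source of errors is a gradient df that does not match f. A finite-difference check such as the sketch below can catch this before calling the solver. The helper finite_difference_grad is a hypothetical utility written for this illustration and is not part of POT.

```python
import numpy as np


def f(G):
    return 0.5 * np.sum(G**2)


def df(G):
    return G


def finite_difference_grad(func, G, eps=1e-6):
    """Central finite-difference approximation of the gradient of func at G."""
    grad = np.zeros_like(G)
    for idx in np.ndindex(G.shape):
        E = np.zeros_like(G)
        E[idx] = eps
        grad[idx] = (func(G + E) - func(G - E)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
G_test = rng.random((4, 6))  # arbitrary test point with the shape of a plan

# Largest disagreement between the analytic and the numerical gradient
err = np.abs(df(G_test) - finite_difference_grad(f, G_test)).max()
print(err)  # should be close to zero
```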

Note

The examples above use the unified API of POT. The classic API is still available, and the entropic OT plan and loss can be computed with the ot.sinkhorn() and ot.sinkhorn2() functions as below:

Gs = ot.sinkhorn(a, b, C, reg=1e-1)
loss_sink = ot.sinkhorn2(a, b, C, reg=1e-1)

For quadratic regularization, the ot.smooth.smooth_ot_dual() function can be used to compute the solution of the regularized OT problem. For user-defined regularization, the ot.optim.cg() function can be used to solve the regularized OT problem with Conditional Gradient algorithm.

Unbalanced and Partial OT

Unbalanced Optimal Transport

Unbalanced OT relaxes the marginal constraints and allows the total source and target weights to differ. The ot.solve_sample() function solves the unbalanced OT problem when the marginal penalization parameter unbalanced is set to a positive value.

# Solve the unbalanced OT problem with KL penalization
P_unb_kl = ot.solve_sample(x1, x2, a, b, unbalanced=5e-2).plan

# Unbalanced with KL penalization and KL regularization
P_unb_kl_reg = ot.solve_sample(
    x1, x2, a, b, unbalanced=5e-2, reg=1e-1
).plan  # also regularized

# Unbalanced with L2 penalization
P_unb_l2 = ot.solve_sample(x1, x2, a, b, unbalanced=7e1, unbalanced_type="L2").plan
Unbalanced KL plan, Unbalanced KL + reg plan, Unbalanced L2 plan
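
To see how the penalization strength relaxes the marginals, the sketch below implements the textbook scaling iterations for entropic OT with KL penalties on both marginals. It only covers the regularized KL case and is not POT's implementation. The name toy_unbalanced_sinkhorn and the toy data are illustrative.

```python
import numpy as np


def toy_unbalanced_sinkhorn(a, b, C, reg, reg_m, n_iter=2000):
    """Scaling iterations for entropic OT with KL-penalized marginals."""
    K = np.exp(-C / reg)
    power = reg_m / (reg_m + reg)  # below 1: marginals are only softly enforced
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** power
        v = (b / (K.T @ u)) ** power
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 2))
xt = rng.normal(size=(7, 2))
C_toy = np.sum((xs[:, None, :] - xt[None, :, :]) ** 2, axis=-1)
C_toy /= C_toy.max()
a_toy = np.full(5, 1 / 5)
b_toy = np.full(7, 1 / 7)

P_strict = toy_unbalanced_sinkhorn(a_toy, b_toy, C_toy, reg=0.2, reg_m=10.0)
P_loose = toy_unbalanced_sinkhorn(a_toy, b_toy, C_toy, reg=0.2, reg_m=0.01)

# Distance between the row sums of each plan and the source weights
gap_strict = np.abs(P_strict.sum(axis=1) - a_toy).max()
gap_loose = np.abs(P_loose.sum(axis=1) - a_toy).max()
print(gap_strict, gap_loose)
```

A large penalization keeps the row sums close to the source weights, while a small one lets the plan drift away from them.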

Note

Solving the unbalanced OT problem with the classic API can be done with the ot.unbalanced.sinkhorn_unbalanced() function as below:

G_unb_kl = ot.unbalanced.sinkhorn_unbalanced(a, b, C, reg=1e-1, reg_m=5e-2)  # entropic reg and KL marginal penalization

Partial Optimal Transport

# Solve the Unbalanced OT problem with TV penalization (equivalent)
P_part_pen = ot.solve_sample(x1, x2, a, b, unbalanced=3, unbalanced_type="TV").plan

# Solve the Partial OT problem with mass constraints (only classic API)
P_part_const = ot.partial.partial_wasserstein(a, b, C, m=0.5)  # 50% mass transported
Partial TV plan, Partial 50% mass plan

Gromov-Wasserstein and Fused Gromov-Wasserstein

Gromov-Wasserstein and Entropic GW

The Gromov-Wasserstein distance is a similarity measure between metric measure spaces, so it does not require the samples to be in the same space.

# Define the metric cost matrices in each space

C1 = ot.dist(x1, x1, metric="sqeuclidean")
C2 = ot.dist(x2, x2, metric="sqeuclidean")

C1 /= C1.max()
C2 /= C2.max()

# Solve the Gromov-Wasserstein problem
sol_gw = ot.solve_gromov(C1, C2, a=a, b=b)
P_gw = sol_gw.plan
loss_gw = sol_gw.value  # quadratic + reg if reg>0
loss_gw_quad = sol_gw.value_quad  # quadratic part of loss

# Solve the Entropic Gromov-Wasserstein problem
P_egw = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2).plan
GW plan, Entropic GW plan
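
Because the Gromov-Wasserstein objective only compares intra-domain distances, it is unchanged when one point cloud is rotated or translated. The sketch below evaluates the squared-loss GW objective for a given plan with NumPy. The functions gw_objective and sq_dists and the toy data are illustrative names, and the four-index tensor makes this suitable for tiny problems only.

```python
import numpy as np


def gw_objective(C1, C2, P):
    """Sum over i, j, k, l of (C1[i, k] - C2[j, l])**2 * P[i, j] * P[k, l]."""
    diff = C1[:, None, :, None] - C2[None, :, None, :]  # indexed by (i, j, k, l)
    return np.einsum("ijkl,ij,kl->", diff**2, P, P)


def sq_dists(x):
    """Squared Euclidean distances between all pairs of rows of x."""
    return np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)


rng = np.random.default_rng(0)
pts = rng.normal(size=(6, 2))

# Rotate and translate the point cloud: pairwise distances are preserved
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])
pts_moved = pts @ R.T + np.array([5.0, -3.0])

D1, D2 = sq_dists(pts), sq_dists(pts_moved)

P_id = np.eye(6) / 6  # matches each point with its moved copy
P_indep = np.full((6, 6), 1 / 36)  # spreads mass uniformly over all pairs

gw_id = gw_objective(D1, D2, P_id)
gw_indep = gw_objective(D1, D2, P_indep)
print(gw_id, gw_indep)  # gw_id is zero up to rounding error
```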

Note

The Gromov-Wasserstein problem can be solved with the classic API using the ot.gromov.gromov_wasserstein() function and the Entropic Gromov-Wasserstein problem can be solved with the ot.gromov.entropic_gromov_wasserstein() function.

P_gw = ot.gromov.gromov_wasserstein(C1, C2, a, b)
P_egw = ot.gromov.entropic_gromov_wasserstein(C1, C2, a, b, epsilon=reg)

loss_gw = ot.gromov.gromov_wasserstein2(C1, C2, a, b)
loss_egw = ot.gromov.entropic_gromov_wasserstein2(C1, C2, a, b, epsilon=reg)

Fused Gromov-Wasserstein

# Cost matrix
M = C / np.max(C)

# Solve FGW problem with alpha=0.1
sol = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1)
P_fgw = sol.plan
loss_fgw = sol.value
loss_fgw_linear = sol.value_linear  # linear part of loss (wrt M)
loss_fgw_quad = sol.value_quad  # quadratic part of loss (wrt C1 and C2)

# Solve entropic FGW problem with alpha=0.1
P_efgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1, reg=1e-3).plan
FGW plan, Entropic FGW plan
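
The role of alpha can be made explicit by evaluating the fused objective for a fixed plan. The sketch below assumes the trade-off convention (1 - alpha) * linear + alpha * quadratic, which should be checked against the ot.solve_gromov() documentation. The function fgw_objective and the toy data are illustrative.

```python
import numpy as np


def fgw_objective(M, C1, C2, P, alpha):
    """Fused objective for a plan P (assumed convention, see the lead-in)."""
    linear = np.sum(M * P)  # feature cost, as in standard OT
    diff = C1[:, None, :, None] - C2[None, :, None, :]
    quad = np.einsum("ijkl,ij,kl->", diff**2, P, P)  # structure cost, as in GW
    return (1 - alpha) * linear + alpha * quad, linear, quad


rng = np.random.default_rng(0)
p1 = rng.normal(size=(4, 2))
p2 = rng.normal(size=(5, 2))

C1_toy = np.sum((p1[:, None, :] - p1[None, :, :]) ** 2, axis=-1)
C2_toy = np.sum((p2[:, None, :] - p2[None, :, :]) ** 2, axis=-1)
M_toy = np.sum((p1[:, None, :] - p2[None, :, :]) ** 2, axis=-1)

# Independent coupling of uniform weights, used only as a feasible plan
P_toy = np.full((4, 5), 1 / 20)

value, linear, quad = fgw_objective(M_toy, C1_toy, C2_toy, P_toy, alpha=0.1)
print(value, linear, quad)
```

Under this convention, alpha=0 reduces to the linear OT cost and alpha=1 to the Gromov-Wasserstein cost.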

Note

The Fused Gromov-Wasserstein problem can be solved with the classic API using the ot.gromov.fused_gromov_wasserstein() function and the Entropic Fused Gromov-Wasserstein problem can be solved with the ot.gromov.entropic_fused_gromov_wasserstein() function.

P_fgw = ot.gromov.fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1)
P_efgw = ot.gromov.entropic_fused_gromov_wasserstein(C1, C2, M, a, b, alpha=0.1, epsilon=reg)

loss_fgw = ot.gromov.fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1)
loss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(C1, C2, M, a, b, alpha=0.1, epsilon=reg)

Large scale OT

We discuss here strategies to solve large scale OT problems using approximations of the exact OT problem.

Large scale Sinkhorn

When the samples contain a large number of points, the Sinkhorn algorithm can be implemented in a lazy version that is more memory efficient and avoids computing the \(n \times m\) cost matrix.

POT provides two implementations of the lazy Sinkhorn algorithm that return their results in a lazy form of type ot.utils.LazyTensor. This object can be used to compute the loss or the OT plan in a lazy way or to recover its values in a dense form.

# Solve the Sinkhorn problem in a lazy way
sol = ot.solve_sample(x1, x2, a, b, reg=1e-1, lazy=True)

# Solve the Sinkhorn problem in a lazy way with geomloss
sol_geo = ot.solve_sample(x1, x2, a, b, reg=1e-1, method="geomloss", lazy=True)

# get the OT lazy plan and loss
P_sink_lazy = sol.lazy_plan

# recover values for Lazy plan
P12 = P_sink_lazy[1, 2]
P1dots = P_sink_lazy[1, :]
# convert to dense matrix !!warning this can be memory consuming
P_sink_lazy_dense = P_sink_lazy[:]
Lazy Sinkhorn OT plan

Note

The lazy Sinkhorn algorithm can be found in the classic API with the ot.bregman.empirical_sinkhorn() function with parameter lazy=True. Similarly, the geomloss implementation is available with the ot.bregman.empirical_sinkhorn2_geomloss() function.

The first example shows how to solve the Sinkhorn problem in a lazy way with the default POT implementation. The second example shows how to solve it in a lazy way with the PyKeOps/GeomLoss implementation, which provides a very efficient way to solve large scale problems on low-dimensional samples.
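
The idea behind a lazy plan can be sketched without POT: an entropic plan is fully described by two scaling vectors and the samples, so any entry or row can be rebuilt on demand without storing the full matrix. The class ToyLazyPlan below is a hypothetical illustration, not ot.utils.LazyTensor, and the random scalings u and v stand in for the output of a Sinkhorn solver.

```python
import numpy as np


class ToyLazyPlan:
    """Stores O(n + m) numbers and builds plan entries only when requested."""

    def __init__(self, x, y, u, v, reg):
        self.x, self.y, self.u, self.v, self.reg = x, y, u, v, reg
        self.shape = (x.shape[0], y.shape[0])

    def row(self, i):
        # Costs from x[i] to every target point, computed on the fly
        cost = np.sum((self.x[i] - self.y) ** 2, axis=1)
        return self.u[i] * np.exp(-cost / self.reg) * self.v

    def entry(self, i, j):
        return self.row(i)[j]

    def dense(self):
        # Warning: this materializes the full (n, m) matrix
        return np.stack([self.row(i) for i in range(self.shape[0])])


rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 2))
xt = rng.normal(size=(7, 2))
u = rng.random(5) + 0.1  # placeholder scalings (assumption, not solver output)
v = rng.random(7) + 0.1

lazy_plan = ToyLazyPlan(xs, xt, u, v, reg=1.0)
p12 = lazy_plan.entry(1, 2)  # a single entry
p1 = lazy_plan.row(1)  # one row, of length 7
P_dense = lazy_plan.dense()  # full matrix, memory consuming for large problems
```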

Factored and Low rank OT

The Sinkhorn algorithm can be implemented in a low rank version that approximates the OT plan with a low rank matrix. This can be useful to accelerate the computation of the OT plan for large scale problems. A similar non-regularized version of low rank factorization is also available.

# Solve the Factored OT problem (use lazy=True for large scale)
P_fact = ot.solve_sample(x1, x2, a, b, method="factored", rank=15).plan

P_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method="lowrank", rank=10).plan
Factored OT plan, Low rank OT plan

Note

The factored OT problem can be solved with the classic API using the ot.factored.factored_optimal_transport() function and the low rank OT problem can be solved with the ot.lowrank.lowrank_sinkhorn() function.

Gaussian OT with Bures-Wasserstein

The Gaussian Wasserstein or Bures-Wasserstein distance is the Wasserstein distance between Gaussian distributions. It can be used as an approximation of the Wasserstein distance between empirical distributions by estimating the covariance matrices of the samples.

# Compute the Bures-Wasserstein distance
bw_value = ot.solve_sample(x1, x2, a, b, method="gaussian").value

print(f"Exact OT loss = {loss:1.3f}")
print(f"Bures-Wasserstein distance = {bw_value:1.3f}")
Exact OT loss = 5.292
Bures-Wasserstein distance = 4.558
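
For reference, the squared Bures-Wasserstein distance between two Gaussians has a closed form that only involves their means and covariances. The NumPy sketch below implements that formula. The function names psd_sqrt and bures_wasserstein_sq are illustrative, and whether a given POT function returns the squared value or its square root should be checked in its documentation. The test case uses covariances 2I and 8I, which are the population covariances of uniform distributions on centered circles of radius 2 and 4.

```python
import numpy as np


def psd_sqrt(S):
    """Square root of a symmetric positive semi-definite matrix."""
    w, V = np.linalg.eigh(S)
    w = np.clip(w, 0.0, None)  # remove tiny negative eigenvalues due to rounding
    return (V * np.sqrt(w)) @ V.T


def bures_wasserstein_sq(m1, S1, m2, S2):
    """Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2)."""
    S1_half = psd_sqrt(S1)
    cross = psd_sqrt(S1_half @ S2 @ S1_half)
    return np.sum((m1 - m2) ** 2) + np.trace(S1 + S2 - 2 * cross)


m = np.zeros(2)
bw_sq = bures_wasserstein_sq(m, 2 * np.eye(2), m, 8 * np.eye(2))
print(bw_sq)  # approximately 4
```

In this idealized setting the value matches the cost of moving every point radially from the inner circle to the outer one, that is (4 - 2)**2.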

Note

The Gaussian Wasserstein problem can be solved with the classic API using the ot.gaussian.empirical_bures_wasserstein_distance() function.

Comparing all OT plans

The figure below shows all the OT plans computed in this example. The color intensity represents the amount of mass transported between the samples.

# plot all plans
OT plan, Sinkhorn plan, Quadratic reg. plan, Unbalanced KL plan, Unbalanced KL + reg plan, Unbalanced L2 plan, Partial 50% mass plan, Factored OT plan, Low rank OT plan, GW plan, Entropic GW plan, Fused GW plan

Total running time of the script: (0 minutes 16.179 seconds)
