Note
Go to the end to download the full example code.
Sliced Unbalanced optimal transport
This example illustrates the difference in behavior between Sliced Unbalanced OT (SUOT) and Unbalanced Sliced OT (USOT), both introduced in [82]. SUOT removes outliers on each slice, while USOT removes outliers of the original marginals.
[82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research.
# Author: Clément Bonet <clement.bonet.mapp@polytechnique.edu>
# Nicolas Courty <nicolas.courty@irisa.fr>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 4
import numpy as np
import matplotlib.pylab as pl
import ot
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from sklearn.neighbors import KernelDensity
Generate data
# Seed NumPy's global random generator before any sampling is done.
np.random.seed(42)

# Sample counts used by this example; the trailing values are presumably
# the larger setting for a denser figure (TODO confirm).
n_samples = 25  # 500
nb_outliers = 10  # 200

# Source inliers: isotropic Gaussian centred at (-0.5, -0.5), std 0.2.
mu_s = np.full(2, -0.5)
cov_s = 0.2**2 * np.eye(2)

# Source outliers: a tight isotropic cluster (std 0.05) at (-2, -0.5).
mu_s_outliers = np.array([-2.0, -0.5])
cov_s_outliers = 0.05**2 * np.eye(2)

# Target: isotropic Gaussian centred at (1.5, 1.5), std 0.2.
mu_t = np.full(2, 1.5)
cov_t = 0.2**2 * np.eye(2)
def generate_dataset(n_samples):
    """Draw the 2D source and target point clouds used in this example.

    The source cloud stacks ``n_samples`` Gaussian inliers on top of
    ``nb_outliers`` Gaussian outliers; the target cloud holds ``n_samples``
    Gaussian points. Means, covariances and ``nb_outliers`` are read from
    the module-level settings.

    Parameters
    ----------
    n_samples : int
        Number of inlier points drawn for the source and for the target.

    Returns
    -------
    Xs_torch : torch.Tensor
        Source samples (inliers first, then outliers) as a float tensor.
    Xt_torch : torch.Tensor
        Target samples as a float tensor.
    """
    # Source data: inliers, then the outlier cluster appended below them.
    source_inliers = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
    source_outliers = ot.datasets.make_2D_samples_gauss(
        nb_outliers, mu_s_outliers, cov_s_outliers
    )
    source = np.vstack([source_inliers, source_outliers])

    # Target data: drawn last so the order of random draws is unchanged.
    target = ot.datasets.make_2D_samples_gauss(n_samples, mu_t, cov_t)

    return (
        torch.from_numpy(source).float(),
        torch.from_numpy(target).float(),
    )
# Sample the dataset once; Xs and Xt are reused by the rest of the script.
Xs, Xt = generate_dataset(n_samples)

# Scatter both point clouds; unpacking the transposed tensors yields the
# x and y coordinate vectors.
pl.figure(1)
pl.scatter(*Xs.T, color="blue", label="Source data")
pl.scatter(*Xt.T, color="red", label="Target data")
pl.xlim((-2.4, 2.4))
pl.ylim((-1, 2.2))
pl.legend()
pl.show()

Compute SUOT and USOT
# `p` and `num_proj` are forwarded to both solvers below (presumably the
# cost exponent and the number of slices -- confirm against the POT docs).
p = 2
num_proj = 180

# One unit of mass per source / target sample.
a = torch.ones(Xs.shape[0], dtype=torch.float)
b = torch.ones(Xt.shape[0], dtype=torch.float)

# construct projections
# Unit directions at evenly spaced angles on the half circle [0, pi).
# The endpoint is excluded on purpose: theta = pi spans the same line as
# theta = 0, and including it makes the angular step pi / (num_proj - 1),
# which drifts away from the pi / nb_slices step used to draw the slice
# line in the animation.
thetas = np.linspace(0, np.pi, num_proj, endpoint=False)
# NOTE(review): `dir` shadows the builtin of the same name; the name is kept
# because other parts of the file may reference it.
dir = np.stack((np.cos(thetas), np.sin(thetas)), axis=1)
dir_torch = torch.from_numpy(dir).type(torch.float)

# Coordinates of the projections
Xps = (Xs @ dir_torch.T).T  # shape (n_projs, n)
Xpt = (Xt @ dir_torch.T).T

# Projections on the lines: scalar coordinate times the unit direction.
projs_Xps = Xps[:, :, None] * dir_torch[:, None, :]  # shape (n_projs, n, 2)
projs_Xpt = Xpt[:, :, None] * dir_torch[:, None, :]

# Compute SUOT
rho1_SUOT = 1
rho2_SUOT = 1
_, log = ot.unbalanced.sliced_unbalanced_ot(
    Xs,
    Xt,
    (rho1_SUOT, rho2_SUOT),
    a,
    b,
    num_proj,
    p,
    numItermax=10,
    projections=dir_torch.T,
    log=True,
)
# Transposed so that the first axis indexes the slice (they are indexed as
# [slice, sample] by the plotting code).
A_SUOT, B_SUOT = log["a_reweighted"].T, log["b_reweighted"].T

# Compute USOT
rho1_USOT = 1
rho2_USOT = 1
A_USOT, B_USOT, _ = ot.unbalanced_sliced_ot(
    Xs,
    Xt,
    (rho1_USOT, rho2_USOT),
    a,
    b,
    num_proj,
    p,
    numItermax=10,
    projections=dir_torch.T,
)
Sliced Unbalanced OT
SUOT averages UOT problems on different slices. Depending on the slice, SUOT can keep or get rid of the outlier mode.
def get_rot(theta):
    """Return the 2x2 rotation matrix for the angle ``theta``.

    Parameters
    ----------
    theta : float
        Rotation angle in radians.

    Returns
    -------
    numpy.ndarray
        The array ``[[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]``.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])
# Visualisation settings for the slice animation.
nb_slices = 180  # 60
# Stride, in projection indices, between two displayed slices.
offset_degree = 180 // nb_slices
# Angle (in radians) by which the drawn line turns between two frames.
delta_degree = np.pi / nb_slices

# One shade of red per displayed slice.
colors = plt.cm.Reds(np.linspace(0.3, 1, nb_slices))

# End points of the horizontal segment that is rotated to draw each slice.
X2 = np.array([4, 0])
X1 = -X2

# Weights are rescaled per slice when drawing; a single global scale would be
# max_weights = max(A_SUOT.max(), B_SUOT.max())
pl.figure(1)
def _rescale_weights(weights, lowest, highest):
    """Map ``weights`` affinely from ``[lowest, highest]`` onto ``[0.1, 1]``."""
    span = highest - lowest
    if span <= 0:
        # Every weight is identical: dividing by the zero span would give
        # NaN, which is not a valid alpha or marker size. Draw everything
        # fully opaque instead.
        return np.ones_like(weights)
    return 0.1 + 0.9 * (weights - lowest) / span


def _update_plot(i):
    """Redraw the current figure for the ``i``-th displayed slice.

    Clears the figure, draws the slice line, the segments joining every
    sample to its projection on that line, and the samples themselves.
    Opacity and marker size follow the SUOT reweighted masses of the slice,
    rescaled jointly over source and target to ``[0.1, 1]``.

    Parameters
    ----------
    i : int
        Index of the displayed slice (the frame number supplied by the
        animation).

    Returns
    -------
    int
        Always ``1``.
    """
    slice_idx = i * offset_degree

    weights_src = A_SUOT[slice_idx, :].cpu().numpy()
    weights_tgt = B_SUOT[slice_idx, :].cpu().numpy()

    # Use one common scale for both clouds so their opacities are comparable.
    max_weights = max(weights_src.max(), weights_tgt.max())
    min_weights = min(weights_src.min(), weights_tgt.min())
    weights_src = _rescale_weights(weights_src, min_weights, max_weights)
    weights_tgt = _rescale_weights(weights_tgt, min_weights, max_weights)

    # Rotate the horizontal segment [X1, X2] to the direction of the slice.
    R = get_rot(delta_degree * (-i))
    X1_r = X1.dot(R)
    X2_r = X2.dot(R)

    pl.clf()

    pl.plot(
        [X1_r[0], X2_r[0]], [X1_r[1], X2_r[1]], color=colors[i], alpha=0.8, zorder=0
    )

    # Segments from each sample to its projection on the slice line.
    src_feet = projs_Xps[slice_idx]
    for j, point in enumerate(Xs):
        pl.plot(
            [point[0], src_feet[j, 0]],
            [point[1], src_feet[j, 1]],
            c="blue",
            alpha=weights_src[j],
        )

    tgt_feet = projs_Xpt[slice_idx]
    for j, point in enumerate(Xt):
        pl.plot(
            [point[0], tgt_feet[j, 0]],
            [point[1], tgt_feet[j, 1]],
            c="red",
            alpha=weights_tgt[j],
        )

    pl.scatter(
        Xs[:, 0],
        Xs[:, 1],
        s=100 * weights_src,
        alpha=weights_src,
        zorder=1,
        color="blue",
        label="Source data",
        edgecolors="black",
    )

    pl.scatter(
        Xt[:, 0],
        Xt[:, 1],
        s=100 * weights_tgt,
        alpha=weights_tgt,
        zorder=1,
        color="red",
        label="Target data",
        edgecolors="black",
    )

    pl.xlim(-2.4, 2.4)
    pl.ylim(-1, 2.2)

    return 1
# Static first frame: the same drawing as `_update_plot` produces for slice 0
# (no rotation of the line), with an extra legend label on the direction line.
weights_src = A_SUOT[0, :].cpu().numpy()
weights_tgt = B_SUOT[0, :].cpu().numpy()

# Rescale the weights of both clouds jointly to [0.1, 1] so they can be used
# as opacity and marker-size factors.
max_weights = max(weights_src.max(), weights_tgt.max())
min_weights = min(weights_src.min(), weights_tgt.min())
weight_span = max_weights - min_weights
if weight_span > 0:
    weights_src = 0.1 + 0.9 * (weights_src - min_weights) / weight_span
    weights_tgt = 0.1 + 0.9 * (weights_tgt - min_weights) / weight_span
else:
    # Every weight is identical: dividing by the zero span would give NaN,
    # which is not a valid alpha or marker size. Draw everything opaque.
    weights_src = np.ones_like(weights_src)
    weights_tgt = np.ones_like(weights_tgt)

X1_r, X2_r = X1, X2

pl.plot(
    [X1_r[0], X2_r[0]],
    [X1_r[1], X2_r[1]],
    color=colors[0],
    alpha=0.8,
    zorder=0,
    label="Directions",
)

# Segments from each sample to its projection on the first slice line.
for j in range(len(Xs)):
    pl.plot(
        [Xs[j, 0], projs_Xps[0, j, 0]],
        [Xs[j, 1], projs_Xps[0, j, 1]],
        c="blue",
        alpha=weights_src[j],
    )

for j in range(len(Xt)):
    pl.plot(
        [Xt[j, 0], projs_Xpt[0, j, 0]],
        [Xt[j, 1], projs_Xpt[0, j, 1]],
        c="red",
        alpha=weights_tgt[j],
    )

pl.scatter(
    Xs[:, 0],
    Xs[:, 1],
    s=100 * weights_src,
    alpha=weights_src,
    zorder=1,
    color="blue",
    label="Source data",
    edgecolors="black",
)

pl.scatter(
    Xt[:, 0],
    Xt[:, 1],
    s=100 * weights_tgt,
    alpha=weights_tgt,
    zorder=1,
    color="red",
    label="Target data",
    edgecolors="black",
)

pl.xlim(-2.4, 2.4)
pl.ylim(-1, 2.2)

# Animate over the slices: `_update_plot` is called with the frame index
# 0 .. nb_slices - 1, one frame every 100 ms. The result is bound to `ani`
# so the animation object stays referenced.
ani = animation.FuncAnimation(
    pl.gcf(),
    _update_plot,
    nb_slices,
    interval=100,  # , repeat_delay=2000
)