.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/backends/plot_ot_batch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_backends_plot_ot_batch.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_backends_plot_ot_batch.py:

===================================================
Solving Many Optimal Transport Problems in Parallel
===================================================

In some situations, one may want to solve many OT problems that share the same
structure (same number of samples, same cost function, etc.) at the same time.
In that case, using a Python loop to solve the problems sequentially is
inefficient. This example shows how to use the batch solvers implemented in
POT to solve many problems in parallel on CPU or GPU (the gains are even
larger on GPU).

.. GENERATED FROM PYTHON SOURCE LINES 14-21

.. code-block:: Python


    # Author: Paul Krzakala
    # License: MIT License
    # sphinx_gallery_thumbnail_number = 1

.. GENERATED FROM PYTHON SOURCE LINES 22-37

Computing the Cost Matrices
---------------------------------------------

We want to create a batch of optimal transport problems with :math:`n` samples
in :math:`d` dimensions. To do this, we first need to compute the cost
matrices for each problem.

.. note::
    A straightforward approach would be to use a Python loop and
    :func:`ot.dist`. However, this is inefficient when working with batches.
    Instead, you can directly use :func:`ot.batch.dist_batch`, which computes
    all cost matrices in parallel.
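To see why a single batched call can replace the loop, here is a minimal NumPy sketch of a batched squared-Euclidean cost computation. This is an illustration only, not POT's actual implementation, and the helper name ``sq_dist_batch`` is ours: it uses the identity :math:`\|x-y\|^2 = \|x\|^2 + \|y\|^2 - 2\langle x, y\rangle` so that one batched matrix product covers all problems at once.

```python
import numpy as np


def sq_dist_batch(X, Y):
    """X: (b, n, d), Y: (b, m, d) -> (b, n, m) squared Euclidean cost matrices."""
    X_sq = np.sum(X**2, axis=2)[:, :, None]  # (b, n, 1) squared norms
    Y_sq = np.sum(Y**2, axis=2)[:, None, :]  # (b, 1, m) squared norms
    cross = X @ Y.transpose(0, 2, 1)         # (b, n, m) batched inner products
    return X_sq + Y_sq - 2 * cross


rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8, 2))
Y = rng.standard_normal((4, 8, 2))
M = sq_dist_batch(X, Y)
print(M.shape)  # (4, 8, 8)
```

The loop over problems disappears entirely: broadcasting and the batched ``@`` operator handle the leading batch dimension.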
.. GENERATED FROM PYTHON SOURCE LINES 38-64

.. code-block:: Python


    import ot
    import numpy as np

    n_problems = 4  # nb problems / batch size
    n_samples = 8  # nb samples
    dim = 2  # nb dimensions

    np.random.seed(0)
    samples_source = np.random.randn(n_problems, n_samples, dim)
    samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)

    # Naive approach
    M_list = []
    for i in range(n_problems):
        M_list.append(
            ot.dist(samples_source[i], samples_target[i])
        )  # List of cost matrices n_samples x n_samples

    # Batched approach
    M_batch = ot.dist_batch(
        samples_source, samples_target
    )  # Array of cost matrices n_problems x n_samples x n_samples

    for i in range(n_problems):
        assert np.allclose(M_list[i], M_batch[i])

.. GENERATED FROM PYTHON SOURCE LINES 65-78

Solving the Problems
---------------------------------------------

Once the cost matrices are computed, we can solve the corresponding optimal
transport problems.

.. note::
    One option is to solve them sequentially with a Python loop using
    :func:`ot.solve`. This is simple but inefficient for large batches.
    Instead, you can use :func:`ot.batch.solve_batch`, which solves all
    problems in parallel.

.. GENERATED FROM PYTHON SOURCE LINES 79-98

.. code-block:: Python


    reg = 1.0
    max_iter = 100
    tol = 1e-3

    # Naive approach
    results_values_list = []
    for i in range(n_problems):
        res = ot.solve(M_list[i], reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
        results_values_list.append(res.value_linear)

    # Batched approach
    results_batch = ot.solve_batch(
        M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
    )
    results_values_batch = results_batch.value_linear

    assert np.allclose(np.array(results_values_list), results_values_batch, atol=tol * 10)

.. GENERATED FROM PYTHON SOURCE LINES 99-107

Comparing Computation Time
---------------------------------------------

We now compare the runtime of the two approaches on larger problems.

.. note::
    The speedup obtained with :mod:`ot.batch` can be even more significant
    when computations are performed on a GPU.
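A large part of the speedup comes from the fact that the inner loop of an entropic solver is just matrix arithmetic, which vectorizes cleanly over a batch. The following is a rough NumPy sketch of batched Sinkhorn scaling iterations for uniform marginals; it is an illustration of the idea only, not the actual code behind :func:`ot.batch.solve_batch`, and ``sinkhorn_batch`` is a hypothetical helper.

```python
import numpy as np


def sinkhorn_batch(M, reg, n_iter=200):
    """Batched Sinkhorn iterations with uniform marginals.

    M: (n_batch, n, m) cost tensor -> (n_batch, n, m) transport plans.
    """
    n_batch, n, m = M.shape
    a = np.full((n_batch, n), 1.0 / n)  # uniform source weights
    b = np.full((n_batch, m), 1.0 / m)  # uniform target weights
    K = np.exp(-M / reg)                # (n_batch, n, m) Gibbs kernels
    u = np.ones((n_batch, n))
    v = np.ones((n_batch, m))
    for _ in range(n_iter):
        # One batched contraction updates the scalings of every problem at once
        u = a / np.einsum("bnm,bm->bn", K, v)
        v = b / np.einsum("bnm,bn->bm", K, u)
    return u[:, :, None] * K * v[:, None, :]  # diag(u) @ K @ diag(v), batched


rng = np.random.default_rng(0)
M = rng.random((4, 8, 8))
T = sinkhorn_batch(M, reg=1.0)
# The target marginal constraint is satisfied for every problem in the batch
print(np.allclose(T.sum(axis=1), 1.0 / 8))  # True
```

No loop over the batch dimension appears anywhere: each ``einsum`` processes all problems in one shot, which is exactly the pattern that maps well onto a GPU.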
.. GENERATED FROM PYTHON SOURCE LINES 108-148

.. code-block:: Python


    from time import perf_counter

    n_problems = 128
    n_samples = 8
    dim = 2
    reg = 10.0
    max_iter = 1000
    tol = 1e-3

    samples_source = np.random.randn(n_problems, n_samples, dim)
    samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)


    def benchmark_naive(samples_source, samples_target):
        start = perf_counter()
        for i in range(n_problems):
            M = ot.dist(samples_source[i], samples_target[i])
            res = ot.solve(M, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
        end = perf_counter()
        return end - start


    def benchmark_batch(samples_source, samples_target):
        start = perf_counter()
        M_batch = ot.dist_batch(samples_source, samples_target)
        res_batch = ot.solve_batch(
            M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
        )
        end = perf_counter()
        return end - start


    time_naive = benchmark_naive(samples_source, samples_target)
    time_batch = benchmark_batch(samples_source, samples_target)

    print(f"Naive approach time: {time_naive:.4f} seconds")
    print(f"Batched approach time: {time_batch:.4f} seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Naive approach time: 0.3529 seconds
    Batched approach time: 0.0109 seconds

.. GENERATED FROM PYTHON SOURCE LINES 149-177

Gromov-Wasserstein
---------------------------------------------

The :mod:`ot.batch` module also provides a batched Gromov-Wasserstein solver.

.. note::
    This solver is **not** equivalent to calling :func:`ot.solve_gromov`
    repeatedly in a loop. Key differences:

    - :func:`ot.solve_gromov` uses the conditional gradient algorithm, where
      each inner iteration relies on an exact EMD solver.
    - :func:`ot.batch.solve_gromov_batch` uses a proximal variant, where each
      inner iteration applies entropic regularization.

    As a result:

    - :func:`ot.solve_gromov` is usually faster on CPU.
    - :func:`ot.batch.solve_gromov_batch` is slower on CPU, but provides
      better objective values.
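Both solvers minimize the same quantity, the squared-loss Gromov-Wasserstein energy :math:`\sum_{i,j,k,l} (C_1[i,k] - C_2[j,l])^2 \, T[i,j] \, T[k,l]`. As a hedged illustration (the helper ``gw_objective_batch`` is ours, not part of POT), this objective can be evaluated for a whole batch of couplings with a few ``einsum`` contractions, using the standard decomposition into two marginal terms and a cross term:

```python
import numpy as np


def gw_objective_batch(C1, C2, T):
    """Batched squared-loss GW energy.

    C1: (b, n, n), C2: (b, m, m), T: (b, n, m) couplings -> (b,) energies.
    """
    a = T.sum(axis=2)  # (b, n) first marginals of each coupling
    p = T.sum(axis=1)  # (b, m) second marginals of each coupling
    term1 = np.einsum("bik,bi,bk->b", C1**2, a, a)
    term2 = np.einsum("bjl,bj,bl->b", C2**2, p, p)
    cross = np.einsum("bik,bij,bjl,bkl->b", C1, T, C2, T)
    return term1 + term2 - 2 * cross


rng = np.random.default_rng(0)
C1 = rng.random((4, 8, 8))
C2 = rng.random((4, 8, 8))
T = np.full((4, 8, 8), 1.0 / 64)  # product coupling as a simple baseline
print(gw_objective_batch(C1, C2, T).shape)  # (4,)
```

Comparing this value across solvers is how "better objective values" is measured in the benchmark below.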
.. tip::
    If your data is on a GPU, :func:`ot.batch.solve_gromov_batch` is
    significantly faster *and* provides better objective values.

.. GENERATED FROM PYTHON SOURCE LINES 178-214

.. code-block:: Python


    from ot import solve_gromov, solve_gromov_batch


    def benchmark_naive_gw(samples_source, samples_target):
        start = perf_counter()
        avg_value = 0
        for i in range(n_problems):
            C1 = ot.dist(samples_source[i], samples_source[i])
            C2 = ot.dist(samples_target[i], samples_target[i])
            res = solve_gromov(C1, C2, max_iter=1000, tol=tol)
            avg_value += res.value
        avg_value /= n_problems
        end = perf_counter()
        return end - start, avg_value


    def benchmark_batch_gw(samples_source, samples_target):
        start = perf_counter()
        C1_batch = ot.dist_batch(samples_source, samples_source)
        C2_batch = ot.dist_batch(samples_target, samples_target)
        res_batch = solve_gromov_batch(
            C1_batch, C2_batch, reg=1, max_iter=100, max_iter_inner=50, tol=tol
        )
        avg_value = np.mean(res_batch.value)
        end = perf_counter()
        return end - start, avg_value


    time_naive_gw, avg_value_naive_gw = benchmark_naive_gw(samples_source, samples_target)
    time_batch_gw, avg_value_batch_gw = benchmark_batch_gw(samples_source, samples_target)

    print(f"{'Method':<20}{'Time (s)':<15}{'Avg Value':<15}")
    print(f"{'Naive GW':<20}{time_naive_gw:<15.4f}{avg_value_naive_gw:<15.4f}")
    print(f"{'Batched GW':<20}{time_batch_gw:<15.4f}{avg_value_batch_gw:<15.4f}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Method              Time (s)       Avg Value
    Naive GW            0.1095         0.7070
    Batched GW          0.4482         0.2914

.. GENERATED FROM PYTHON SOURCE LINES 215-217

In summary: no more for loops!
---------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 218-227
.. code-block:: Python


    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.text(0.5, 0.5, "For", fontsize=160, ha="center", va="center", zorder=0)
    ax.axis("off")
    ax.plot([0, 1], [0, 1], color="red", linewidth=10, zorder=1)
    ax.plot([0, 1], [1, 0], color="red", linewidth=10, zorder=1)
    plt.show()

.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_ot_batch_001.png
    :alt: plot ot batch
    :srcset: /auto_examples/backends/images/sphx_glr_plot_ot_batch_001.png
    :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.980 seconds)

.. _sphx_glr_download_auto_examples_backends_plot_ot_batch.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_ot_batch.ipynb <plot_ot_batch.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_ot_batch.py <plot_ot_batch.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_ot_batch.zip <plot_ot_batch.zip>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_