.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/unbalanced-partial/plot_entropic_partial_wasserstein_logscale.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_unbalanced-partial_plot_entropic_partial_wasserstein_logscale.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_unbalanced-partial_plot_entropic_partial_wasserstein_logscale.py:


==========================================================================
Numerically-stable entropic partial Wasserstein (log-domain solver)
==========================================================================

.. note::
    Example added in release: 0.9.7.

`ot.partial.entropic_partial_wasserstein` is numerically unstable at small
regularisation: the iterates underflow to zero and the returned plan contains
NaNs (see PythonOT/POT issue #723). This example reproduces the failure mode
on a small problem and shows that the log-domain solver, selected with
``entropic_partial_wasserstein(..., method='sinkhorn_log')`` (equivalently
:any:`ot.partial.entropic_partial_wasserstein_logscale`), produces a finite
plan over the same sweep, agreeing with the original solver at large ``reg``
and degrading gracefully at small ``reg``.

Following the :any:`ot.sinkhorn` convention, the solver to use is chosen
through the ``method`` parameter: ``'sinkhorn'`` (default) for the classical
solver and ``'sinkhorn_log'`` for the log-domain one.

The log-domain solver is slower per iteration than the standard one, so the
recommendation is to use the standard solver by default and fall back to the
log-domain solver when ``reg`` is small enough to risk underflow.

.. GENERATED FROM PYTHON SOURCE LINES 26-36

.. code-block:: Python


    # Author: wzm2256 (original PR #724)
    # License: MIT License

    import numpy as np
    import scipy as sp
    import matplotlib.pylab as pl

    import ot

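
To make the underflow concrete, the following minimal sketch (not POT's
implementation) evaluates the Gibbs kernel ``exp(-M / reg)`` on a hypothetical
2x2 cost matrix at roughly the cost scale used in this example. The names
``M_toy``, ``reg_toy`` and ``a_toy`` are illustrative assumptions and are not
part of the example script. In the linear domain every kernel entry underflows
to zero, so normalising by the row sums divides by zero; the same row sums
evaluated in the log domain with ``scipy.special.logsumexp`` stay finite. A
log-domain solver presumably relies on this kind of computation, although the
actual solver may differ in detail.

.. code-block:: Python

    import numpy as np
    from scipy.special import logsumexp

    # Hypothetical toy values, chosen only to illustrate float64 underflow.
    M_toy = np.array([[40.0, 45.0], [50.0, 42.0]])  # costs at the ~50 scale
    reg_toy = 0.05
    a_toy = np.array([0.5, 0.5])  # uniform source weights

    # Linear domain: -M / reg is at most -800, and exp(-800) is below the
    # smallest positive float64, so every kernel entry becomes exactly 0.
    K_toy = np.exp(-M_toy / reg_toy)
    row_sums_linear = K_toy.sum(axis=1)

    # Dividing the marginal by zero row sums yields inf, which turns into
    # NaN as soon as it is multiplied by the zero kernel.
    with np.errstate(divide="ignore", invalid="ignore"):
        scaling_linear = a_toy / row_sums_linear
        plan_linear = scaling_linear[:, None] * K_toy

    # Log domain: log(sum_j exp(-M_ij / reg)) without ever forming the kernel.
    log_row_sums = logsumexp(-M_toy / reg_toy, axis=1)
    log_scaling = np.log(a_toy) - log_row_sums

    print("kernel all zero:      ", bool(np.all(K_toy == 0.0)))
    print("linear plan has NaN:  ", bool(np.isnan(plan_linear).any()))
    print("log row sums finite:  ", bool(np.isfinite(log_row_sums).all()))

The log-domain scaling ``log_scaling`` is large but finite, which is why
working with logarithms of the iterates avoids the NaNs seen in the linear
domain.
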

.. GENERATED FROM PYTHON SOURCE LINES 37-41

Construct a 50x50 cost matrix
-----------------------------

Mirrors the cost-matrix scale (~50) used in PythonOT/POT issue #723.

.. GENERATED FROM PYTHON SOURCE LINES 41-52

.. code-block:: Python


    rng = np.random.RandomState(0)
    n = 50
    xs = rng.rand(n, 2)
    xt = rng.rand(n, 2)
    M = sp.spatial.distance.cdist(xs, xt) * 50.0

    a = np.ones(n) / n
    b = np.ones(n) / n
    m = 0.6  # transport ~60% of the mass

.. GENERATED FROM PYTHON SOURCE LINES 53-63

Sweep regularisation
--------------------

Run both solvers across a range of ``reg`` values. On this 50×50 problem at
cost-scale 50 the standard solver returns NaN at the ``reg`` values closest
to the underflow boundary (typically ``reg`` ~0.05–0.01 in our runs, though
the exact transition depends on the BLAS / platform's float64 underflow
behaviour); the log-domain solver stays finite over the whole sweep, including
the very small ``reg`` regime where the standard exp(−M/reg) path would
underflow to zero everywhere.

.. GENERATED FROM PYTHON SOURCE LINES 63-92

.. code-block:: Python


    regs = [1.0, 0.5, 0.1, 0.05, 0.01, 5e-3, 1e-3, 5e-4]
    standard_finite = []
    logscale_finite = []
    standard_mass = []
    logscale_mass = []

    for reg in regs:
        G_std = ot.partial.entropic_partial_wasserstein(
            a, b, M, reg=reg, m=m, numItermax=2000
        )
        G_log = ot.partial.entropic_partial_wasserstein(
            a, b, M, reg=reg, m=m, method="sinkhorn_log", numItermax=2000
        )
        standard_finite.append(bool(np.isfinite(G_std).all()))
        logscale_finite.append(bool(np.isfinite(G_log).all()))
        standard_mass.append(float(G_std.sum()) if np.isfinite(G_std).all() else np.nan)
        logscale_mass.append(float(G_log.sum()))

    print(
        "reg standard_finite logscale_finite std_mass logscale_mass (target m={:.2f})".format(
            m
        )
    )
    for reg, sf, lf, sm, lm in zip(
        regs, standard_finite, logscale_finite, standard_mass, logscale_mass
    ):
        print(f"{reg:>10.4g} {str(sf):<14} {str(lf):<14} {sm:>8.3f} {lm:>8.3f}")



.. rst-class:: sphx-glr-script-out


 .. code-block:: none

    /home/circleci/project/ot/partial/partial_solvers.py:620: RuntimeWarning: invalid value encountered in divide
      q1 = q1 * Kprev / K1
    /home/circleci/project/ot/partial/partial_solvers.py:624: RuntimeWarning: invalid value encountered in divide
      q2 = q2 * K1prev / K2
    /home/circleci/project/ot/partial/partial_solvers.py:628: RuntimeWarning: invalid value encountered in divide
      q3 = q3 * K2prev / K
    Warning: numerical errors at iteration 1
    /home/circleci/project/ot/partial/partial_solvers.py:619: RuntimeWarning: divide by zero encountered in divide
      K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
    /home/circleci/project/ot/partial/partial_solvers.py:619: RuntimeWarning: overflow encountered in divide
      K1 = nx.dot(nx.diag(nx.minimum(a / nx.sum(K, axis=1), dx)), K)
    /home/circleci/project/ot/partial/partial_solvers.py:623: RuntimeWarning: divide by zero encountered in divide
      K2 = nx.dot(K1, nx.diag(nx.minimum(b / nx.sum(K1, axis=0), dy)))
    Warning: numerical errors at iteration 1
    reg standard_finite logscale_finite std_mass logscale_mass (target m=0.60)
             1 True           True              0.600    0.600
           0.5 True           True              0.600    0.600
           0.1 True           True              0.600    0.600
          0.05 False          True                nan    0.600
          0.01 False          True                nan    0.600
         0.005 True           True              0.600    0.600
         0.001 True           True              0.600    0.600
        0.0005 True           True              0.600    0.600




.. GENERATED FROM PYTHON SOURCE LINES 93-95

Plot the resulting plans at large vs. small reg
-----------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 95-121


.. code-block:: Python


    fig, axes = pl.subplots(2, 2, figsize=(9, 8))
    for ax, reg in zip(axes[:, 0], (1.0, 0.01)):
        G_std = ot.partial.entropic_partial_wasserstein(
            a, b, M, reg=reg, m=m, numItermax=2000
        )
        if not np.isfinite(G_std).all():
            G_std = np.zeros_like(G_std)
            ax.set_title(f"standard, reg={reg} (NaN)")
        else:
            ax.set_title(f"standard, reg={reg}")
        ax.imshow(G_std, cmap="viridis", aspect="auto")
        ax.set_xlabel("target")
        ax.set_ylabel("source")
    for ax, reg in zip(axes[:, 1], (1.0, 0.01)):
        G_log = ot.partial.entropic_partial_wasserstein(
            a, b, M, reg=reg, m=m, method="sinkhorn_log", numItermax=2000
        )
        ax.set_title(f"logscale, reg={reg}")
        ax.imshow(G_log, cmap="viridis", aspect="auto")
        ax.set_xlabel("target")
        ax.set_ylabel("source")
    fig.tight_layout()
    pl.show()



.. image-sg:: /auto_examples/unbalanced-partial/images/sphx_glr_plot_entropic_partial_wasserstein_logscale_001.png
   :alt: standard, reg=1.0, logscale, reg=1.0, standard, reg=0.01 (NaN), logscale, reg=0.01
   :srcset: /auto_examples/unbalanced-partial/images/sphx_glr_plot_entropic_partial_wasserstein_logscale_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Warning: numerical errors at iteration 1




.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 15.671 seconds)


.. _sphx_glr_download_auto_examples_unbalanced-partial_plot_entropic_partial_wasserstein_logscale.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_entropic_partial_wasserstein_logscale.ipynb <plot_entropic_partial_wasserstein_logscale.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_entropic_partial_wasserstein_logscale.py <plot_entropic_partial_wasserstein_logscale.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_entropic_partial_wasserstein_logscale.zip <plot_entropic_partial_wasserstein_logscale.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_