.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/backends/plot_gradient_descent.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_backends_plot_gradient_descent.py:


===============================================================================
Solve Fused Unbalanced Gromov Wasserstein with Adam
===============================================================================

Since the FUGW loss is differentiable with respect to the transport plan, it
can be minimized with first-order methods. We show how to do this by
minimizing the batched loss `loss_quadratic_batch` with PyTorch's Adam
optimizer, and we compare the result with the dedicated FUGW solver
`fused_unbalanced_gromov_wasserstein`, which we call through the generic
`ot.solve_gromov` interface.

.. GENERATED FROM PYTHON SOURCE LINES 11-29

.. code-block:: Python


    # Author: Rémi Flamary
    #         Sonia Mazelet
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 3

    import numpy as np
    import matplotlib.pylab as pl
    import torch
    from time import perf_counter

    import ot
    from ot.batch._quadratic import loss_quadratic_batch, tensor_batch
    from ot.gromov import fused_unbalanced_gromov_wasserstein
    from sklearn.manifold import MDS

.. GENERATED FROM PYTHON SOURCE LINES 30-32

Generation of source and target graphs
--------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 32-123

.. code-block:: Python


    rng = np.random.RandomState(42)


    def get_sbm(n, nc, ratio, P):
        # Sample a symmetric stochastic block model (SBM) adjacency matrix:
        # a node of block c1 is linked to a node of block c2 <= c1 with
        # probability P[c1, c2] (only the lower triangle of P is used).
        nbpc = np.round(n * ratio).astype(int)
        n = np.sum(nbpc)
        C = np.zeros((n, n))
        for c1 in range(nc):
            for c2 in range(c1 + 1):
                if c1 == c2:
                    for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):
                        for j in range(np.sum(nbpc[:c2]), i):
                            if rng.rand() <= P[c1, c2]:
                                C[i, j] = 1
                else:
                    for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):
                        for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[: c2 + 1])):
                            if rng.rand() <= P[c1, c2]:
                                C[i, j] = 1

        return C + C.T


    def plot_graph(x, C, color="C0", s=100):
        # Draw the edges of C between the 2D node positions x, then the nodes.
        for j in range(C.shape[0]):
            for i in range(j):
                if C[i, j] > 0:
                    pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k")
        pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k")


    def get_sbm_labels(n, ratio):
        # Block label of every node, in the same node order as get_sbm.
        nbpc = np.round(n * ratio).astype(int)
        return np.concatenate(
            [np.full(count, label, dtype=int) for label, count in enumerate(nbpc)]
        )


    def get_noisy_one_hot(labels, n_classes, noise_level=0.1):
        # One-hot encoding of the labels perturbed by Gaussian noise.
        x = np.eye(n_classes)[labels]
        x += noise_level * rng.randn(*x.shape)
        return x


    n1 = 15
    n2 = 10
    nc1 = 3
    nc2 = 2
    ratio1 = np.array([0.33, 0.33, 0.33])
    ratio2 = np.array([0.5, 0.5])
    P1 = np.array([[0.8, 0.03, 0.0], [0.08, 0.8, 0.03], [0.0, 0.08, 0.8]])
    P2 = np.array(0.8 * np.eye(2) + 0.01 * np.ones((2, 2)))
    C1 = get_sbm(n1, nc1, ratio1, P1)
    C2 = get_sbm(n2, nc2, ratio2, P2)
    labels1 = get_sbm_labels(n1, ratio1)
    labels2 = get_sbm_labels(n2, ratio2)

    # Use noisy one-hot encodings of the SBM classes as node features.
    feature_dim = max(nc1, nc2)
    x1 = get_noisy_one_hot(labels1, feature_dim)
    x2 = get_noisy_one_hot(labels2, feature_dim)

    # Common min/max over both graphs, used to map the 3 feature dimensions
    # to RGB colors in [0, 1].
    all_features = np.vstack([x1, x2])
    feature_min = all_features[:, :3].min(axis=0, keepdims=True)
    feature_max = all_features[:, :3].max(axis=0, keepdims=True)

    # get 2d positions for visualization
    pos1 = MDS(dissimilarity="precomputed", random_state=0, n_init=1).fit_transform(1 - C1)
    pos2 = MDS(dissimilarity="precomputed", random_state=0, n_init=1).fit_transform(1 - C2)

    colors1 = np.clip(
        (x1 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0
    )
    colors2 = np.clip(
        (x2 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0
    )

    pl.figure(1, (10, 5))
    pl.clf()
    pl.subplot(1, 2, 1)
    plot_graph(pos1, C1, color=colors1)
    pl.title("SBM source graph")
    pl.axis("off")
    pl.subplot(1, 2, 2)
    plot_graph(pos2, C2, color=colors2)
    pl.title("SBM target graph")
    _ = pl.axis("off")




.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_gradient_descent_001.png
   :alt: SBM source graph, SBM target graph
   :srcset: /auto_examples/backends/images/sphx_glr_plot_gradient_descent_001.png
   :class: sphx-glr-single-img

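
The nested loops in `get_sbm` test every node pair one at a time. As an
illustration, the same sampling rule (a pair i > j is linked with probability
`P[c_i, c_j]`, where c_i and c_j are the blocks of the two nodes) can be written
with NumPy array operations. This is a sketch, not part of the example: the
name `sample_sbm` is hypothetical, it takes block sizes instead of a ratio, and
it draws random numbers in a different order, so it does not reproduce the
graphs of `get_sbm` for the same seed.

.. code-block:: Python

    import numpy as np


    def sample_sbm(sizes, P, rng):
        """Sample a symmetric SBM adjacency matrix without self-loops (sketch)."""
        # Block label of every node; nodes are sorted by block.
        labels = np.repeat(np.arange(len(sizes)), sizes)
        # Edge probability for every node pair, read from the block matrix P.
        probs = P[labels[:, None], labels[None, :]]
        draws = rng.rand(len(labels), len(labels)) <= probs
        # Keep the strict lower triangle (i > j) and mirror it: undirected graph.
        lower = np.tril(draws, k=-1)
        return (lower | lower.T).astype(float), labels


    rng = np.random.RandomState(0)
    P = np.array([[0.8, 0.03], [0.03, 0.8]])
    C, labels = sample_sbm([5, 5], P, rng)

Because the nodes are sorted by block, every lower-triangular pair satisfies
c_i >= c_j, so only the lower-triangular blocks of `P` are read, as in
`get_sbm`.
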

.. GENERATED FROM PYTHON SOURCE LINES 124-126

Solve FUGW with Adam
--------------------

.. GENERATED FROM PYTHON SOURCE LINES 126-187

.. code-block:: Python


    # Even though `loss_quadratic_batch` supports batches of problems, we use a
    # batch of size 1 here for clarity.

    a = ot.unif(C1.shape[0])
    b = ot.unif(C2.shape[0])
    M = ot.dist(x1, x2)
    M /= M.max()

    a_torch = torch.tensor(a[None, :])
    b_torch = torch.tensor(b[None, :])
    C1_torch = torch.tensor(C1[None, :, :])
    C2_torch = torch.tensor(C2[None, :, :])
    M_torch = torch.tensor(M[None, :, :])

    L = tensor_batch(a_torch, b_torch, C1_torch, C2_torch, loss="sqeuclidean")

    alpha = 0.5
    reg_marginals = 0.5
    lr = 5e-2
    nb_iter_max = 1500
    tol = 1e-7

    # Start from the product coupling a b^T. log(expm1(.)) is the inverse of
    # softplus, so softplus(T_torch) equals T0_torch at initialization.
    T0_torch = a_torch[:, :, None] * b_torch[:, None, :]
    T_torch = torch.log(torch.expm1(T0_torch)).clone().requires_grad_(True)
    optimizer = torch.optim.Adam([T_torch], lr=lr)

    loss_iter = []
    mass_iter = []
    previous_plan_torch = None

    tic = perf_counter()
    for i in range(nb_iter_max):
        optimizer.zero_grad()
        # Positive transport plan parameterized as log(1 + exp(T)).
        plan_torch = torch.nn.functional.softplus(T_torch)
        loss = loss_quadratic_batch(
            a_torch,
            b_torch,
            C1_torch,
            C2_torch,
            plan_torch,
            M_torch,
            alpha=alpha,
            unbalanced=reg_marginals,
            unbalanced_type="kl",
            recompute_const=True,
        )[0]
        loss_iter.append(float(loss.detach()))
        mass_iter.append(float(plan_torch.detach().sum()))

        # Stop when the plan changes by less than tol (in L1 norm) between
        # two consecutive iterations.
        if previous_plan_torch is not None:
            err = float(torch.sum(torch.abs(plan_torch.detach() - previous_plan_torch)))
            if err < tol:
                break
        previous_plan_torch = plan_torch.detach().clone()

        loss.backward()
        optimizer.step()
    time_adam = perf_counter() - tic

    T_adam = torch.nn.functional.softplus(T_torch).detach().cpu().numpy()[0]

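
The loop above keeps the plan positive by optimizing an unconstrained variable
mapped through softplus, starting from the inverse softplus `log(expm1(T0))` of
the product coupling. The sketch below illustrates the same idea without
PyTorch, under simplifying assumptions: it minimizes a *linear* unbalanced OT
objective, <M, T> + rho (KL(T 1 | a) + KL(T^T 1 | b)), with no
Gromov-Wasserstein term (so it is not the loss computed by
`loss_quadratic_batch`), and it uses a hand-written Adam update with the
gradient obtained by the chain rule through softplus. All function names are
hypothetical.

.. code-block:: Python

    import numpy as np


    def softplus(z):
        # Numerically stable log(1 + exp(z)).
        return np.logaddexp(0.0, z)


    def inv_softplus(t):
        # Inverse of softplus for t > 0: log(exp(t) - 1).
        return np.log(np.expm1(t))


    def sigmoid(z):
        # Derivative of softplus.
        return 1.0 / (1.0 + np.exp(-z))


    def kl_div(p, q):
        # Generalized KL divergence between positive vectors.
        return np.sum(p * np.log(p / q) - p + q)


    def uot_loss_and_grad(T, M, a, b, rho):
        # Linear unbalanced OT loss (simplified stand-in, no GW term) and its
        # gradient with respect to T.
        r, c = T.sum(axis=1), T.sum(axis=0)
        loss = np.sum(M * T) + rho * (kl_div(r, a) + kl_div(c, b))
        grad = M + rho * (np.log(r / a)[:, None] + np.log(c / b)[None, :])
        return loss, grad


    def adam_uot(M, a, b, rho=0.5, lr=5e-2, n_iter=500,
                 beta1=0.9, beta2=0.999, eps=1e-8):
        Z = inv_softplus(np.outer(a, b))  # softplus(Z) = a b^T at the start
        m = np.zeros_like(Z)
        v = np.zeros_like(Z)
        losses = []
        for k in range(1, n_iter + 1):
            T = softplus(Z)
            loss, grad_T = uot_loss_and_grad(T, M, a, b, rho)
            losses.append(loss)
            grad_Z = grad_T * sigmoid(Z)  # chain rule through softplus
            # Standard Adam moment estimates with bias correction.
            m = beta1 * m + (1 - beta1) * grad_Z
            v = beta2 * v + (1 - beta2) * grad_Z**2
            m_hat = m / (1 - beta1**k)
            v_hat = v / (1 - beta2**k)
            Z -= lr * m_hat / (np.sqrt(v_hat) + eps)
        return softplus(Z), losses


    rng = np.random.RandomState(0)
    xs, xt = rng.randn(6, 2), rng.randn(4, 2)
    M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)
    M /= M.max()
    a, b = np.full(6, 1 / 6), np.full(4, 1 / 4)
    T, losses = adam_uot(M, a, b)

Extending this sketch to FUGW would require adding the Gromov-Wasserstein term
and its gradient to `uot_loss_and_grad`; the example above avoids deriving that
gradient by hand by relying on PyTorch autograd (`loss.backward()`).
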

.. GENERATED FROM PYTHON SOURCE LINES 188-194

Compare with the dedicated FUGW solver
--------------------------------------

The dedicated solver uses a block coordinate descent (BCD) scheme. We compare
the coupling it returns with the one obtained by minimizing
`loss_quadratic_batch` directly with Adam. Both couplings are scored with the
same batched loss so that the values are comparable.

.. GENERATED FROM PYTHON SOURCE LINES 194-242

.. code-block:: Python


    def evaluate_batch_fugw_loss(plan):
        # Evaluate the batched FUGW loss used in the Adam loop on a NumPy plan.
        plan_torch = torch.tensor(plan[None, :, :], dtype=M_torch.dtype)
        loss = loss_quadratic_batch(
            a_torch,
            b_torch,
            C1_torch,
            C2_torch,
            plan_torch,
            M_torch,
            alpha=alpha,
            unbalanced=reg_marginals,
            unbalanced_type="kl",
            recompute_const=True,
        )[0]
        return float(loss.detach())


    tic = perf_counter()
    result = ot.solve_gromov(
        C1, C2, M, a, b, alpha=alpha, reg=0, unbalanced_type="kl", unbalanced=reg_marginals
    )
    time_bcd = perf_counter() - tic

    loss_adam_final = evaluate_batch_fugw_loss(T_adam)
    T_bcd = result.plan
    loss_bcd_final = evaluate_batch_fugw_loss(T_bcd)
    mass_bcd = T_bcd.sum()

    pl.figure(2, (10, 4))
    pl.clf()
    pl.subplot(1, 2, 1)
    pl.plot(loss_iter, label="Adam")
    pl.axhline(loss_bcd_final, color="C1", linestyle="--", label="BCD solver")
    pl.grid()
    pl.title("FUGW loss along iterations")
    pl.xlabel("Iterations")
    pl.legend()
    pl.subplot(1, 2, 2)
    pl.plot(mass_iter, label="Adam")
    pl.axhline(mass_bcd, color="C1", linestyle="--", label="BCD solver")
    pl.grid()
    pl.title("Transport mass")
    pl.xlabel("Iterations")
    _ = pl.legend()




.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_gradient_descent_002.png
   :alt: FUGW loss along iterations, Transport mass
   :srcset: /auto_examples/backends/images/sphx_glr_plot_gradient_descent_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 243-248

Visualize the learned couplings
-------------------------------

We visualize the couplings obtained by both methods to compare them. On this
example, both methods recover similar couplings and reach nearly identical loss
values (0.147 for Adam versus 0.146 for the BCD solver in the run shown below),
but direct minimization is several times slower than the dedicated solver.

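
Besides visual inspection, two couplings can be compared with a few scalar
summaries. The helper below is a hypothetical sketch, not part of POT: it
reports the total mass of each plan (which can differ from 1 in the unbalanced
setting), their entrywise L1 distance normalized by the average mass, and the
fraction of source nodes whose heaviest target is the same in both plans.

.. code-block:: Python

    import numpy as np


    def compare_couplings(T_a, T_b):
        """Summarize how close two couplings of the same shape are (sketch)."""
        T_a = np.asarray(T_a, dtype=float)
        T_b = np.asarray(T_b, dtype=float)
        if T_a.shape != T_b.shape:
            raise ValueError("couplings must have the same shape")
        l1 = np.abs(T_a - T_b).sum()
        return {
            "mass_a": T_a.sum(),
            "mass_b": T_b.sum(),
            # Entrywise L1 distance, normalized by the average transported mass.
            "relative_l1": l1 / (0.5 * (T_a.sum() + T_b.sum())),
            # Fraction of source nodes whose heaviest target is the same.
            "argmax_agreement": np.mean(T_a.argmax(axis=1) == T_b.argmax(axis=1)),
        }


    T_a = np.array([[0.4, 0.1], [0.0, 0.5]])
    T_b = np.array([[0.3, 0.2], [0.1, 0.4]])
    summary = compare_couplings(T_a, T_b)

In this example it could be applied as `compare_couplings(T_adam, T_bcd)` once
both plans have been computed.
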

.. GENERATED FROM PYTHON SOURCE LINES 248-267

.. code-block:: Python


    # Use a common color scale so that both couplings are directly comparable.
    vmin = min(T_adam.min(), T_bcd.min())
    vmax = max(T_adam.max(), T_bcd.max())

    pl.figure(3, (10, 4))
    pl.clf()
    pl.subplot(1, 2, 1)
    pl.imshow(T_adam, interpolation="nearest", cmap="Blues", vmin=vmin, vmax=vmax)
    pl.title(
        f"Coupling from direct minimization\nloss={loss_adam_final:.3f}, time={time_adam:.2f}s"
    )
    pl.xlabel("Target nodes")
    pl.ylabel("Source nodes")
    pl.colorbar()
    pl.subplot(1, 2, 2)
    pl.imshow(T_bcd, interpolation="nearest", cmap="Blues", vmin=vmin, vmax=vmax)
    pl.title(f"Coupling from BCD solver\nloss={loss_bcd_final:.3f}, time={time_bcd:.2f}s")
    pl.xlabel("Target nodes")
    pl.ylabel("Source nodes")
    _ = pl.colorbar()




.. image-sg:: /auto_examples/backends/images/sphx_glr_plot_gradient_descent_003.png
   :alt: Coupling from direct minimization loss=0.147, time=3.87s, Coupling from BCD solver loss=0.146, time=0.52s
   :srcset: /auto_examples/backends/images/sphx_glr_plot_gradient_descent_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 5.076 seconds)


.. _sphx_glr_download_auto_examples_backends_plot_gradient_descent.py: