.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_gromov_plot_semirelaxed_gromov_wasserstein_barycenter.py:


=========================================================================
Semi-relaxed (Fused) Gromov-Wasserstein Barycenter as Dictionary Learning
=========================================================================

In this example, we illustrate how to learn a semi-relaxed Gromov-Wasserstein
(srGW) barycenter [48] using a Block-Coordinate Descent (BCD) algorithm, on a
dataset of structured data such as graphs, denoted
:math:`\{ \mathbf{C_s} \}_{s \in [S]}`, where the nodes of every graph have
uniform weights :math:`\{ \mathbf{p_s} \}_{s \in [S]}`. Given a barycenter
structure matrix :math:`\mathbf{C}` with :math:`N` nodes, each graph
:math:`(\mathbf{C_s}, \mathbf{p_s})` is modeled as a reweighted subgraph with
structure :math:`\mathbf{C}` and weights :math:`\mathbf{w_s} \in \Sigma_N`,
where each :math:`\mathbf{w_s}` is the second marginal of the OT plan
:math:`\mathbf{T_s}` (i.e. :math:`\mathbf{w_s} = \mathbf{T_s}^\top \mathbf{1}`)
that minimizes the srGW loss between the :math:`s`-th input and the barycenter.

First, we consider a dataset composed of graphs generated by Stochastic Block
Models, with sizes drawn between 30 and 50 nodes and a number of clusters
varying in :math:`\{ 1, 2, 3\}` with random cluster proportions. We learn an
srGW barycenter with 3 nodes and visualize the learned structure and the
embeddings of the input graphs.

Second, we illustrate the extension of this framework to graphs endowed with
node features by using the semi-relaxed Fused Gromov-Wasserstein divergence
(srFGW).
Starting from the aforementioned dataset of unattributed graphs, we add
discrete node labels given by the SBM cluster assignment of each node, then
conduct the analogous analysis.

[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas
Courty. "Semi-relaxed Gromov-Wasserstein divergence and applications on
graphs". International Conference on Learning Representations (ICLR), 2022.

.. GENERATED FROM PYTHON SOURCE LINES 37-51

.. code-block:: Python

    # Author: Cédric Vincent-Cuaz
    #
    # License: MIT License

    # sphinx_gallery_thumbnail_number = 2

    import numpy as np
    import matplotlib.pylab as pl
    from sklearn.manifold import MDS
    from ot.gromov import semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters
    import ot
    import networkx
    from networkx.generators.community import stochastic_block_model as sbm

.. GENERATED FROM PYTHON SOURCE LINES 52-54

Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
-----------------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 55-130

.. code-block:: Python

    np.random.seed(42)

    n_samples = 60  # number of graphs in the dataset
    # For every number of clusters, we generate SBMs with fixed inter/intra-cluster
    # probabilities and variable cluster proportions.
    clusters = [1, 2, 3]
    Nc = n_samples // len(clusters)  # number of graphs per number of clusters
    nlabels = len(clusters)
    dataset = []
    node_labels = []
    labels = []

    p_inter = 0.1
    p_intra = 0.9
    for n_cluster in clusters:
        for i in range(Nc):
            n_nodes = int(np.random.uniform(low=30, high=50))

            if n_cluster > 1:
                P = p_inter * np.ones((n_cluster, n_cluster))
                np.fill_diagonal(P, p_intra)
                props = np.random.uniform(0.2, 1, size=(n_cluster,))
                props /= props.sum()
                sizes = np.round(n_nodes * props).astype(np.int32)
            else:
                P = p_intra * np.eye(1)
                sizes = [n_nodes]

            G = sbm(sizes, P, seed=i, directed=False)
            # ground-truth block (cluster) assignment of every node
            part = np.array([G.nodes[node]["block"] for node in range(np.sum(sizes))])
            C = networkx.to_numpy_array(G)
            dataset.append(C)
            node_labels.append(part)
            labels.append(n_cluster)


    # Visualize samples


    def plot_graph(x, C, binary=True, color="C0", s=None):
        for j in range(C.shape[0]):
            for i in range(j):
                if binary:
                    if C[i, j] > 0:
                        pl.plot(
                            [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k"
                        )
                else:  # connection intensity proportional to C[i,j]
                    pl.plot(
                        [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color="k"
                    )

        pl.scatter(
            x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k", cmap="tab10", vmax=9
        )


    pl.figure(1, (12, 8))
    pl.clf()
    for c in clusters:
        C = dataset[(c - 1) * Nc]  # sample with c clusters
        # get 2d position for nodes
        x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C)
        pl.subplot(2, nlabels, c)
        pl.title("(graph) sample from label " + str(c), fontsize=14)
        plot_graph(x, C, binary=True, color="C0", s=50.0)
        pl.axis("off")
        pl.subplot(2, nlabels, nlabels + c)
        pl.title("(matrix) sample from label %s \n" % c, fontsize=14)
        pl.imshow(C, interpolation="nearest")
        pl.axis("off")
    pl.tight_layout()
    pl.show()
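The dataset above relies on the NetworkX ``sbm`` generator (``stochastic_block_model``). As a hedged illustration of the model it samples, and not a reproduction of the NetworkX implementation, the minimal NumPy sketch below links every pair of distinct nodes independently with probability :math:`P_{ab}`, where :math:`a` and :math:`b` are the blocks of the two nodes, and returns a symmetric adjacency matrix without self-loops (assumed here to match the undirected ``directed=False`` graphs used above). The helper ``sample_sbm_adjacency`` is hypothetical and belongs to neither NetworkX nor POT.

.. code-block:: Python

    import numpy as np


    def sample_sbm_adjacency(sizes, P, rng):
        """Hypothetical helper: sample a symmetric 0/1 SBM adjacency matrix.

        A node in block a and a node in block b are linked with probability
        P[a, b]; no self-loops are drawn.
        """
        sizes = np.asarray(sizes, dtype=int)
        P = np.asarray(P, dtype=float)
        # block index of every node, e.g. sizes=[2, 3] -> [0, 0, 1, 1, 1]
        blocks = np.repeat(np.arange(len(sizes)), sizes)
        n = blocks.size
        # edge probability for every ordered pair of nodes, shape (n, n)
        probs = P[blocks[:, None], blocks[None, :]]
        # draw each unordered pair once (strict upper triangle), then mirror it
        upper = np.triu(rng.random((n, n)) < probs, k=1)
        A = (upper | upper.T).astype(float)
        return A, blocks


    rng = np.random.default_rng(0)
    P = np.full((3, 3), 0.1)  # inter-cluster probability off the diagonal
    np.fill_diagonal(P, 0.9)  # intra-cluster probability on the diagonal
    A, blocks = sample_sbm_adjacency([10, 15, 20], P, rng)

    # empirical edge densities inside and across blocks
    same = blocks[:, None] == blocks[None, :]
    off_diag = ~np.eye(len(blocks), dtype=bool)
    print("intra-block density:", A[same & off_diag].mean())
    print("inter-block density:", A[~same].mean())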
.. image-sg:: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_001.png
   :alt: (graph) sample from label 1, (matrix) sample from label 1 , (graph) sample from label 2, (matrix) sample from label 2 , (graph) sample from label 3, (matrix) sample from label 3
   :srcset: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/circleci/project/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py:108: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap', 'vmax' will be ignored
      pl.scatter(

.. GENERATED FROM PYTHON SOURCE LINES 131-133

Estimate the srGW barycenter from the dataset and visualize embeddings
----------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 134-215

.. code-block:: Python

    np.random.seed(0)
    ps = [ot.unif(C.shape[0]) for C in dataset]  # uniform weights on input nodes
    lambdas = [1.0 / n_samples for _ in range(n_samples)]  # uniform barycentric weights
    N = 3  # 3 nodes in the barycenter

    # Here we use the Fluid partitioning method to deduce initial transport plans
    # for the barycenter problem. An initial structure is also deduced from these
    # initial transport plans. Then a warmstart strategy is used iteratively to
    # init each individual srGW problem within the BCD algorithm.
    # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan`
    init_plan = "fluid"
    warmstartT = True

    C, log = semirelaxed_gromov_barycenters(
        N=N,
        Cs=dataset,
        ps=ps,
        lambdas=lambdas,
        loss_fun="square_loss",
        tol=1e-6,
        stop_criterion="loss",
        warmstartT=warmstartT,
        log=True,
        G0=init_plan,
        verbose=False,
    )

    print("barycenter structure:", C)

    unmixings = log["p"]
    # Compute the 2D representation of the embeddings living in the probability 2-simplex
    unmixings2D = np.zeros(shape=(n_samples, 2))
    for i, w in enumerate(unmixings):
        unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0
        unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0
    x = [0.0, 0.0]
    y = [1.0, 0.0]
    z = [0.5, np.sqrt(3) / 2.0]
    extremities = np.stack([x, y, z])

    pl.figure(2, (4, 4))
    pl.clf()
    pl.title("Embedding space", fontsize=14)
    for cluster in range(nlabels):
        start, end = Nc * cluster, Nc * (cluster + 1)
        label = "1 cluster" if cluster == 0 else "%s clusters" % (cluster + 1)
        pl.scatter(
            unmixings2D[start:end, 0],
            unmixings2D[start:end, 1],
            c="C" + str(cluster),
            marker="o",
            s=80.0,
            label=label,
        )

    pl.scatter(
        extremities[:, 0],
        extremities[:, 1],
        c="black",
        marker="x",
        s=100.0,
        label="bary. nodes",
    )
    pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0)
    pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0)
    pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0)
    pl.axis("off")
    pl.legend(fontsize=11)
    pl.tight_layout()
    pl.show()

.. image-sg:: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_002.png
   :alt: Embedding space
   :srcset: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_002.png
   :class: sphx-glr-single-img
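In the semi-relaxed problem, the first marginal of each plan :math:`\mathbf{T_s}` is fixed to :math:`\mathbf{p_s}` while the second one is left free, and the embedding of the :math:`s`-th graph is :math:`\mathbf{w_s} = \mathbf{T_s}^\top \mathbf{1}`. The minimal NumPy sketch below illustrates the two steps behind the "Embedding space" figure: it builds a plan that sends every node of a toy graph entirely to one of the :math:`N = 3` barycenter nodes (a hard assignment assumed purely for readability, not an srGW solution), reads off :math:`\mathbf{w_s}`, and maps it onto the triangle with vertices :math:`(0, 0)`, :math:`(1, 0)` and :math:`(1/2, \sqrt{3}/2)`, which is the mapping used to compute ``unmixings2D``. The helpers ``second_marginal`` and ``simplex_to_2d`` are hypothetical.

.. code-block:: Python

    import numpy as np


    def second_marginal(T):
        """Hypothetical helper: barycenter weights w = T^T 1 of a plan T of shape (n_s, N)."""
        return T.sum(axis=0)


    def simplex_to_2d(W):
        """Hypothetical helper: map rows of W (shape (S, 3), in the 2-simplex) to 2D points.

        Each row is used as barycentric coordinates on the triangle with vertices
        (0, 0), (1, 0) and (1/2, sqrt(3)/2).
        """
        vertices = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3.0) / 2.0]])
        return np.asarray(W) @ vertices


    # toy input graph with 6 uniformly weighted nodes, each node being sent
    # entirely to one of the N = 3 barycenter nodes (illustrative assumption)
    p_s = np.full(6, 1.0 / 6.0)
    assignment = np.array([0, 0, 0, 1, 1, 2])
    T = np.zeros((6, 3))
    T[np.arange(6), assignment] = p_s  # rows sum to p_s: the first marginal is fixed

    w = second_marginal(T)  # free second marginal: mass received by each barycenter node
    print("w_s =", w)  # proportions 3/6, 2/6 and 1/6
    print("2D point:", simplex_to_2d(w[None, :]))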
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    barycenter structure: [[0.85792315 0.09415624 0.10183026]
     [0.09415624 0.86111455 0.09633968]
     [0.10183026 0.09633968 0.84805335]]

.. GENERATED FROM PYTHON SOURCE LINES 216-220

Endow the dataset with node features
------------------------------------

Node labels, corresponding to the true SBM cluster assignments, are set for
each graph as one-hot encoded node features.

.. GENERATED FROM PYTHON SOURCE LINES 221-248

.. code-block:: Python

    dataset_features = []
    for i in range(len(dataset)):
        n = dataset[i].shape[0]
        F = np.zeros((n, 3))
        F[np.arange(n), node_labels[i]] = 1.0  # one-hot encoding of the SBM cluster
        dataset_features.append(F)

    pl.figure(3, (12, 8))
    pl.clf()
    for c in clusters:
        C = dataset[(c - 1) * Nc]  # sample with c clusters
        F = dataset_features[(c - 1) * Nc]
        # color each node by its SBM cluster, i.e. the index of its one-hot feature
        colors = [f"C{k}" for k in F.argmax(axis=1)]
        # get 2d position for nodes
        x = MDS(dissimilarity="precomputed", random_state=0).fit_transform(1 - C)
        pl.subplot(2, nlabels, c)
        pl.title("(graph) sample from label " + str(c), fontsize=14)
        plot_graph(x, C, binary=True, color=colors, s=50)
        pl.axis("off")
        pl.subplot(2, nlabels, nlabels + c)
        pl.title("(matrix) sample from label %s \n" % c, fontsize=14)
        pl.imshow(C, interpolation="nearest")
        pl.axis("off")
    pl.tight_layout()
    pl.show()

.. image-sg:: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_003.png
   :alt: (graph) sample from label 1, (matrix) sample from label 1 , (graph) sample from label 2, (matrix) sample from label 2 , (graph) sample from label 3, (matrix) sample from label 3
   :srcset: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_003.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/circleci/project/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py:108: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap', 'vmax' will be ignored
      pl.scatter(

.. GENERATED FROM PYTHON SOURCE LINES 249-255

Estimate the srFGW barycenter from the attributed graphs and visualize embeddings
---------------------------------------------------------------------------------

We emphasize the dependence on the trade-off parameter :math:`\alpha`, which
weights the relative importance of the structures (:math:`\alpha = 1`) and of
the features (:math:`\alpha = 0`). Since the node features encode the SBM
clusters, embeddings that perfectly cluster graphs with respect to their
features should ease the identification of the number of clusters in the
graphs.

.. GENERATED FROM PYTHON SOURCE LINES 256-334

.. code-block:: Python

    list_alphas = [0.0001, 0.5, 0.9999]
    list_unmixings2D = []

    for alpha in list_alphas:
        print("--- alpha:", alpha)
        # the srFGW barycenter solver returns the features first, then the structure
        F, C, log = semirelaxed_fgw_barycenters(
            N=N,
            Ys=dataset_features,
            Cs=dataset,
            ps=ps,
            lambdas=lambdas,
            alpha=alpha,
            loss_fun="square_loss",
            tol=1e-6,
            stop_criterion="loss",
            warmstartT=warmstartT,
            log=True,
            G0=init_plan,
        )

        print("barycenter structure:", C)
        print("barycenter features:", F)

        unmixings = log["p"]
        # Compute the 2D representation of the embeddings living in the probability 2-simplex
        unmixings2D = np.zeros(shape=(n_samples, 2))
        for i, w in enumerate(unmixings):
            unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0
            unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0
        list_unmixings2D.append(unmixings2D.copy())

    x = [0.0, 0.0]
    y = [1.0, 0.0]
    z = [0.5, np.sqrt(3) / 2.0]
    extremities = np.stack([x, y, z])

    pl.figure(4, (12, 4))
    pl.clf()
    pl.suptitle("Embedding spaces", fontsize=14)
    for ialpha, alpha in enumerate(list_alphas):
        pl.subplot(1, len(list_alphas), ialpha + 1)
        pl.title(f"alpha = {alpha}", fontsize=14)
        for cluster in range(nlabels):
            start, end = Nc * cluster, Nc * (cluster + 1)
            label = "1 cluster" if cluster == 0 else "%s clusters" % (cluster + 1)
            pl.scatter(
                list_unmixings2D[ialpha][start:end, 0],
                list_unmixings2D[ialpha][start:end, 1],
                c="C" + str(cluster),
                marker="o",
                s=80.0,
                label=label,
            )

        pl.scatter(
            extremities[:, 0],
            extremities[:, 1],
            c="black",
            marker="x",
            s=100.0,
            label="bary. nodes",
        )
        pl.plot([x[0], y[0]], [x[1], y[1]], color="black", linewidth=2.0)
        pl.plot([x[0], z[0]], [x[1], z[1]], color="black", linewidth=2.0)
        pl.plot([y[0], z[0]], [y[1], z[1]], color="black", linewidth=2.0)
        pl.axis("off")
        pl.legend(fontsize=11)

    pl.tight_layout()
    pl.show()

.. image-sg:: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_004.png
   :alt: Embedding spaces, alpha = 0.0001, alpha = 0.5, alpha = 0.9999
   :srcset: /auto_examples/gromov/images/sphx_glr_plot_semirelaxed_gromov_wasserstein_barycenter_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    --- alpha: 0.0001
    barycenter structure: [[0.86400199 0.09381731 0.10469915]
     [0.09381731 0.84251959 0.09963309]
     [0.10469915 0.09963309 0.81742118]]
    barycenter features: [[1. 0. 0.]
     [0. 1. 0.]
     [0. 0. 1.]]
    --- alpha: 0.5
    barycenter structure: [[0.84251959 0.09381731 0.09963309]
     [0.09381731 0.86400199 0.10469915]
     [0.09963309 0.10469915 0.81742118]]
    barycenter features: [[0. 1. 0.]
     [1. 0. 0.]
     [0. 0. 1.]]
    --- alpha: 0.9999
    barycenter structure: [[0.85504036 0.10042125 0.0932431 ]
     [0.10042125 0.86156282 0.09962626]
     [0.0932431  0.09962626 0.85292267]]
    barycenter features: [[0.57086318 0.25224954 0.17688729]
     [0.62370377 0.29125323 0.08504301]
     [0.62892887 0.31328417 0.05778696]]
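To make the role of :math:`\alpha` concrete, the sketch below evaluates, for a fixed plan :math:`\mathbf{T}` whose first marginal is :math:`\mathbf{p_s}`, a fused objective of the form :math:`(1 - \alpha) \langle \mathbf{M}, \mathbf{T} \rangle + \alpha \sum_{i,j,k,l} (C_{s,ik} - C_{jl})^2 T_{ij} T_{kl}`, where :math:`\mathbf{M}` is assumed to hold squared Euclidean distances between input and barycenter features. It is a hedged illustration written from the usual (F)GW definitions with the square loss, not POT's solver: the POT barycenter routines also optimize the plans and the barycenter, whereas this only evaluates a given plan. The helpers ``gw_square_loss`` and ``srfgw_objective`` are hypothetical. With :math:`\alpha = 1` only the structures contribute and with :math:`\alpha = 0` only the features do, matching the convention stated above.

.. code-block:: Python

    import numpy as np


    def gw_square_loss(C1, C2, T):
        """Hypothetical helper: sum_{ijkl} (C1[i,k] - C2[j,l])**2 T[i,j] T[k,l], without a 4D tensor."""
        p = T.sum(axis=1)  # first marginal (fixed to p_s in the semi-relaxed problem)
        q = T.sum(axis=0)  # second marginal (free: the embedding w_s)
        return p @ (C1**2) @ p + q @ (C2**2) @ q - 2.0 * np.sum(T * (C1 @ T @ C2.T))


    def srfgw_objective(C1, F1, C2, F2, T, alpha):
        """Hypothetical helper: (1 - alpha) * <M, T> + alpha * GW term, for a fixed plan T."""
        # squared Euclidean distances between input features (n, d) and barycenter features (N, d)
        M = ((F1[:, None, :] - F2[None, :, :]) ** 2).sum(axis=-1)
        return (1.0 - alpha) * np.sum(M * T) + alpha * gw_square_loss(C1, C2, T)


    rng = np.random.default_rng(0)
    n, N = 5, 3
    C1 = rng.random((n, n))
    C1 = (C1 + C1.T) / 2.0  # symmetric input structure
    C2 = rng.random((N, N))
    C2 = (C2 + C2.T) / 2.0  # symmetric barycenter structure
    F1 = np.eye(3)[rng.integers(0, 3, size=n)]  # one-hot input node features
    F2 = np.eye(3)  # one-hot barycenter node features
    T = rng.random((n, N))
    T *= (np.full(n, 1.0 / n) / T.sum(axis=1))[:, None]  # rescale rows so that T 1 = p_s

    for alpha in [0.0001, 0.5, 0.9999]:
        print("alpha:", alpha, "objective:", srfgw_objective(C1, F1, C2, F2, T, alpha))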
.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 4.568 seconds)