Source code for ot.gnn._utils

# -*- coding: utf-8 -*-
"""
GNN layers utils
"""

# Author: Sonia Mazelet <sonia.mazelet@ens-paris-saclay.fr>
#         Rémi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

import torch
from ..utils import dist
from ..gromov import fused_gromov_wasserstein2
from ..lp import emd2
from torch_geometric.utils import subgraph


def TFGW_template_initialization(n_tplt, n_tplt_nodes, n_features, feature_init_mean=0., feature_init_std=1.):
    """
    Initializes templates for the Template Fused Gromov-Wasserstein layer.
    Returns the adjacency matrices and the node features of the templates.
    Adjacency matrices are initialized uniformly at random with values in :math:`[0,1]`.
    Node features are initialized from a normal distribution.

    Parameters
    ----------

      n_tplt: int
        Number of templates.
      n_tplt_nodes: int
        Number of nodes per template.
      n_features: int
        Number of features for the nodes.
      feature_init_mean: float, optional
        Mean of the normal distribution used to initialize the template features.
      feature_init_std: float, optional
        Standard deviation of the normal distribution used to initialize the template features.

    Returns
    -------
      tplt_adjacencies: torch.Tensor, shape (n_templates, n_template_nodes, n_template_nodes)
           Adjacency matrices of the templates.
      tplt_features: torch.Tensor, shape (n_templates, n_template_nodes, n_features)
           Node features for each template.
      q: torch.Tensor, shape (n_templates, n_template_nodes)
           Weights on the template nodes.
    """

    tplt_adjacencies = torch.rand((n_tplt, n_tplt_nodes, n_tplt_nodes))
    tplt_features = torch.Tensor(n_tplt, n_tplt_nodes, n_features)

    torch.nn.init.normal_(tplt_features, mean=feature_init_mean, std=feature_init_std)

    q = torch.zeros(n_tplt, n_tplt_nodes)

    # symmetrize the adjacency matrices
    tplt_adjacencies = 0.5 * (tplt_adjacencies + torch.transpose(tplt_adjacencies, 1, 2))

    return tplt_adjacencies, tplt_features, q
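
# Example usage (illustrative sketch, not part of the module): build 3
# templates of 4 nodes with 2-dimensional features. The returned node
# weights ``q`` are zero-initialized logits; in the TFGW layer they are
# typically normalized (e.g. with a softmax) before being used as weights.
#
#     tplt_adj, tplt_feat, q = TFGW_template_initialization(3, 4, 2)
#     tplt_weights = torch.softmax(q, dim=1)  # zero logits give uniform weights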


def FGW_distance_to_templates(G_edges, tplt_adjacencies, G_features, tplt_features, tplt_weights, alpha=0.5, multi_alpha=False, batch=None):
    """
    Computes the FGW distances between a graph and templates.

    Parameters
    ----------
    G_edges : torch.Tensor, shape (2, n_edges)
        Edge indices of the graph in the PyTorch Geometric format.
    tplt_adjacencies : list of torch.Tensor, shape (n_templates, n_template_nodes, n_template_nodes)
        List of the adjacency matrices of the templates.
    G_features : torch.Tensor, shape (n_nodes, n_features)
        Graph node features.
    tplt_features : list of torch.Tensor, shape (n_templates, n_template_nodes, n_features)
        List of the node features of the templates.
    tplt_weights : torch.Tensor, shape (n_templates, n_template_nodes)
        Weights on the nodes of the templates.
    alpha : float, optional
        Trade-off parameter in :math:`[0, 1]`, weighting the feature term (alpha=0) against the structure term (alpha=1).
    multi_alpha : bool, optional
        If True, the alpha parameter is a vector of size n_templates.
    batch : torch.Tensor, optional
        Batch vector which assigns each node to its graph.

    Returns
    -------
    distances : torch.Tensor, shape (n_templates) if batch=None, else shape (n_graphs, n_templates)
        Vector of fused Gromov-Wasserstein distances between the graph and the templates.
    """

    if batch is None:
        n, n_feat = G_features.shape
        n_T, _, n_feat_T = tplt_features.shape

        # uniform weights on the graph nodes
        weights_G = torch.ones(n) / n

        # dense adjacency matrix of the graph built from the edge index
        C = torch.sparse_coo_tensor(G_edges, torch.ones(len(G_edges[0])), size=(n, n)).type(torch.float)
        C = C.to_dense()

        if not n_feat == n_feat_T:
            raise ValueError('The templates and the graphs must have the same feature dimension.')

        distances = torch.zeros(n_T)

        for j in range(n_T):
            template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
            M = dist(G_features, template_features).type(torch.float)

            # if alpha is zero the emd distance is used
            if multi_alpha and torch.any(alpha > 0):
                embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha[j], symmetric=True, max_iter=50)
            elif not multi_alpha and torch.all(alpha == 0):
                embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)
            elif not multi_alpha and alpha > 0:
                embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha, symmetric=True, max_iter=50)
            else:
                embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)

            distances[j] = embedding

    else:
        n_T, _, n_feat_T = tplt_features.shape
        num_graphs = torch.max(batch) + 1
        distances = torch.zeros(num_graphs, n_T)

        # iterate over the graphs in the batch
        for i in range(num_graphs):
            nodes = torch.where(batch == i)[0]
            G_edges_i, _ = subgraph(nodes, edge_index=G_edges, relabel_nodes=True)
            G_features_i = G_features[nodes]

            n, n_feat = G_features_i.shape
            weights_G = torch.ones(n) / n

            n_edges = len(G_edges_i[0])
            C = torch.sparse_coo_tensor(G_edges_i, torch.ones(n_edges), size=(n, n)).type(torch.float)
            C = C.to_dense()

            if not n_feat == n_feat_T:
                raise ValueError('The templates and the graphs must have the same feature dimension.')

            for j in range(n_T):
                template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
                M = dist(G_features_i, template_features).type(torch.float)

                # if alpha is zero the emd distance is used
                if multi_alpha and torch.any(alpha > 0):
                    embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha[j], symmetric=True, max_iter=50)
                elif not multi_alpha and torch.all(alpha == 0):
                    embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)
                elif not multi_alpha and alpha > 0:
                    embedding = fused_gromov_wasserstein2(M, C, tplt_adjacencies[j], weights_G, tplt_weights[j], alpha=alpha, symmetric=True, max_iter=50)
                else:
                    embedding = emd2(weights_G, tplt_weights[j], M, numItermax=50)

                distances[i, j] = embedding

    return distances
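
# Example usage (illustrative sketch, not part of the module). The edge index
# follows the PyTorch Geometric convention of shape (2, n_edges); alpha is
# passed as a 0-dim tensor so that the torch.all/torch.any checks above
# operate on tensors.
#
#     tplt_adj, tplt_feat, q = TFGW_template_initialization(2, 3, 4)
#     tplt_weights = torch.softmax(q, dim=1)
#     G_edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected path 0-1-2
#     G_features = torch.randn(3, 4)
#     dists = FGW_distance_to_templates(G_edges, tplt_adj, G_features,
#                                       tplt_feat, tplt_weights,
#                                       alpha=torch.tensor(0.5))
#     # dists has shape (2,): one FGW distance per template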
def wasserstein_distance_to_templates(G_features, tplt_features, tplt_weights, batch=None):
    """
    Computes the Wasserstein distances between a graph and graph templates.

    Parameters
    ----------
    G_features : torch.Tensor, shape (n_nodes, n_features)
        Node features of the graph.
    tplt_features : list of torch.Tensor, shape (n_templates, n_template_nodes, n_features)
        List of the node features of the templates.
    tplt_weights : torch.Tensor, shape (n_templates, n_template_nodes)
        Weights on the nodes of the templates.
    batch : torch.Tensor, optional
        Batch vector which assigns each node to its graph.

    Returns
    -------
    distances : torch.Tensor, shape (n_templates) if batch=None, else shape (n_graphs, n_templates)
        Vector of Wasserstein distances between the graph and the templates.
    """

    if batch is None:
        n, n_feat = G_features.shape
        n_T, _, n_feat_T = tplt_features.shape

        # uniform weights on the graph nodes
        weights_G = torch.ones(n) / n

        if not n_feat == n_feat_T:
            raise ValueError('The templates and the graphs must have the same feature dimension.')

        distances = torch.zeros(n_T)

        for j in range(n_T):
            template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
            M = dist(G_features, template_features).type(torch.float)
            distances[j] = emd2(weights_G, tplt_weights[j], M, numItermax=50)

    else:
        n_T, _, n_feat_T = tplt_features.shape
        num_graphs = torch.max(batch) + 1
        distances = torch.zeros(num_graphs, n_T)

        # iterate over the graphs in the batch
        for i in range(num_graphs):
            nodes = torch.where(batch == i)[0]
            G_features_i = G_features[nodes]

            n, n_feat = G_features_i.shape
            weights_G = torch.ones(n) / n

            if not n_feat == n_feat_T:
                raise ValueError('The templates and the graphs must have the same feature dimension.')

            for j in range(n_T):
                template_features = tplt_features[j].reshape(len(tplt_features[j]), n_feat_T)
                M = dist(G_features_i, template_features).type(torch.float)
                distances[i, j] = emd2(weights_G, tplt_weights[j], M, numItermax=50)

    return distances
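
# Example usage (illustrative sketch, not part of the module). The pure
# Wasserstein variant compares node features only, so no edge index is needed.
#
#     tplt_adj, tplt_feat, q = TFGW_template_initialization(2, 3, 4)
#     tplt_weights = torch.softmax(q, dim=1)
#     G_features = torch.randn(5, 4)
#     dists = wasserstein_distance_to_templates(G_features, tplt_feat, tplt_weights)
#     # dists has shape (2,): one Wasserstein distance per template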