Source code for deepsnap.hetero_gnn

import torch
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import torch.nn as nn

from torch import Tensor
# `torch._six` was removed in recent PyTorch releases; `container_abcs`
# is an alias for the standard `collections.abc` module.
import collections.abc as container_abcs
from torch_geometric.nn.inits import reset
from torch_sparse import matmul
from typing import (
    List,
    Dict,
)

# TODO: add a new "HeteroSAGEConv" that supports edge_features
class HeteroSAGEConv(pyg_nn.MessagePassing):
    r"""The heterogeneous compatible GraphSAGE operator, derived from the
    `"Inductive Representation Learning on Large Graphs"
    <https://arxiv.org/abs/1706.02216>`_, `"Modeling polypharmacy side
    effects with graph convolutional networks"
    <https://arxiv.org/abs/1802.00543>`_, and `"Modeling Relational Data
    with Graph Convolutional Networks" <https://arxiv.org/abs/1703.06103>`_
    papers.

    .. note::
        This layer is usually used with :class:`HeteroConv`.

    Args:
        in_channels_neigh (int): The input dimension of the neighbor
            node type.
        out_channels (int): The dimension of the output.
        in_channels_self (int): The input dimension of the self node type.
            Default is `None`, in which case `in_channels_self` is set to
            `in_channels_neigh`.
        remove_self_loop (bool): Whether to remove self loops using
            :func:`torch_geometric.utils.remove_self_loops`.
            Default is `True`.
    """
    def __init__(
        self,
        in_channels_neigh,
        out_channels,
        in_channels_self=None,
        remove_self_loop=True,
    ):
        super(HeteroSAGEConv, self).__init__(aggr="add")
        self.remove_self_loop = remove_self_loop
        self.in_channels_neigh = in_channels_neigh
        if in_channels_self is None:
            self.in_channels_self = in_channels_neigh
        else:
            self.in_channels_self = in_channels_self
        self.out_channels = out_channels
        self.lin_neigh = nn.Linear(self.in_channels_neigh, self.out_channels)
        self.lin_self = nn.Linear(self.in_channels_self, self.out_channels)
        self.lin_update = nn.Linear(self.out_channels * 2, self.out_channels)
    def forward(
        self,
        node_feature_neigh,
        node_feature_self,
        edge_index,
        edge_weight=None,
        size=None,
        res_n_id=None,
    ):
        r"""Propagates messages from the neighbor node type to the self
        node type along `edge_index`.
        """
        if self.remove_self_loop:
            edge_index, _ = pyg_utils.remove_self_loops(edge_index)
        return self.propagate(
            edge_index,
            size=size,
            node_feature_neigh=node_feature_neigh,
            node_feature_self=node_feature_self,
            edge_weight=edge_weight,
            res_n_id=res_n_id,
        )
    def message(self, node_feature_neigh_j, node_feature_self_i, edge_weight):
        r"""Constructs the message for each edge: the neighbor feature is
        passed through unchanged.
        """
        return node_feature_neigh_j
        # torch.cat([node_feature_self_j, edge_feature,
        #            node_feature_self_i], dim=...)
        # TODO: check out homogeneous wordnet message passing
    def message_and_aggregate(self, edge_index, node_feature_neigh):
        r"""Fuses :meth:`message` and :meth:`aggregate` into a single
        function, which saves memory and avoids materializing the messages.
        For more information, please refer to the PyTorch Geometric
        documentation.

        Args:
            edge_index (:class:`torch_sparse.SparseTensor`): The
                `edge_index` sparse tensor.
            node_feature_neigh (:class:`torch.Tensor`): Neighbor feature
                tensor.
        """
        out = matmul(edge_index, node_feature_neigh, reduce="mean")
        return out
    def update(self, aggr_out, node_feature_self, res_n_id):
        r"""Transforms the aggregated neighbor embedding and the self
        embedding with separate linear layers, concatenates them, and
        applies the update layer.
        """
        aggr_out = self.lin_neigh(aggr_out)
        node_feature_self = self.lin_self(node_feature_self)
        aggr_out = torch.cat([aggr_out, node_feature_self], dim=-1)
        aggr_out = self.lin_update(aggr_out)
        return aggr_out
    def __repr__(self):
        return (
            f"{self.__class__.__name__}"
            f"(neigh: {self.in_channels_neigh}, "
            f"self: {self.in_channels_self}, "
            f"out: {self.out_channels})"
        )
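# A minimal usage sketch for HeteroSAGEConv (not part of the original
# source). The node counts, feature dimensions, and the bipartite
# `edge_index` below are invented for illustration: messages flow from
# four neighbor nodes (16 features each) to three self nodes (8 features
# each), so the output has one 32-dimensional embedding per self node.
def _example_hetero_sage_conv():
    conv = HeteroSAGEConv(in_channels_neigh=16, out_channels=32,
                          in_channels_self=8)
    node_feature_neigh = torch.randn(4, 16)
    node_feature_self = torch.randn(3, 8)
    # Row 0 holds neighbor (source) indices, row 1 holds self
    # (destination) indices.
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 0, 0]])
    out = conv(node_feature_neigh, node_feature_self, edge_index,
               size=(4, 3))
    assert out.shape == (3, 32)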
class HeteroConv(torch.nn.Module):
    r"""A "wrapper" layer designed for heterogeneous graph layers. It takes
    a dictionary of heterogeneous graph layers, such as
    :class:`deepsnap.hetero_gnn.HeteroSAGEConv`, at the initialization
    stage. Currently DeepSNAP does not support `parallelize=True`.

    .. note::
        For more detailed usage of :class:`HeteroConv`, see the
        `examples/node_classification_hetero
        <https://github.com/snap-stanford/deepsnap/tree/master/examples/node_classification_hetero>`_
        folder.
    """
    def __init__(self, convs, aggr="add", parallelize=False):
        super(HeteroConv, self).__init__()
        assert isinstance(convs, container_abcs.Mapping)
        self.convs = convs
        self.modules = torch.nn.ModuleList(convs.values())
        assert aggr in ["add", "mean", "max", "mul", "concat", None]
        self.aggr = aggr
        if parallelize and torch.cuda.is_available():
            self.streams = {key: torch.cuda.Stream() for key in convs.keys()}
        else:
            self.streams = None
    def reset_parameters(self):
        r"""Resets the parameters of all wrapped convolution layers."""
        for conv in self.convs.values():
            reset(conv)
    def forward(self, node_features, edge_indices, edge_features=None):
        r"""The forward function for :class:`HeteroConv`.

        Args:
            node_features (Dict[str, Tensor]): A dictionary in which each
                key is a node type and the corresponding value is a node
                feature tensor.
            edge_indices (Dict[str, Tensor]): A dictionary in which each
                key is a message type (a `(neigh_type, edge_type,
                self_type)` tuple) and the corresponding value is an
                `edge_index` tensor.
            edge_features (Dict[str, Tensor]): A dictionary in which each
                key is an edge type and the corresponding value is an edge
                feature tensor. The default value is `None`.
        """
        # TODO: graph is not defined
        if self.streams is not None and graph.not_in_gpu():
            raise RuntimeError("Cannot parallelize on non-gpu graphs")

        # Node embeddings computed from each message type.
        message_type_emb = {}
        for message_key in edge_indices:
            if message_key not in self.convs:
                continue
            neigh_type, edge_type, self_type = message_key
            node_feature_neigh = node_features[neigh_type]
            node_feature_self = node_features[self_type]
            # TODO: edge_features is not used
            if edge_features is not None:
                edge_feature = edge_features[edge_type]
            edge_index = edge_indices[message_key]

            # Perform message passing.
            if self.streams is not None:
                with torch.cuda.stream(self.streams[message_key]):
                    message_type_emb[message_key] = (
                        self.convs[message_key](
                            node_feature_neigh,
                            node_feature_self,
                            edge_index,
                        )
                    )
            else:
                message_type_emb[message_key] = (
                    self.convs[message_key](
                        node_feature_neigh,
                        node_feature_self,
                        edge_index,
                    )
                )

        if self.streams is not None:
            torch.cuda.synchronize()

        # Group the node embeddings computed from different message types
        # by their destination ("tail") node type.
        node_emb = {tail: [] for _, _, tail in message_type_emb.keys()}
        for (_, _, tail), item in message_type_emb.items():
            node_emb[tail].append(item)

        # Aggregate multiple embeddings with the same tail.
        for node_type, embs in node_emb.items():
            if len(embs) == 1:
                node_emb[node_type] = embs[0]
            else:
                node_emb[node_type] = self.aggregate(embs)
        return node_emb
    def aggregate(self, xs: List[Tensor]):
        r"""The aggregation for each node type. Currently supports
        `concat`, `add`, `mean`, `max` and `mul`.

        Args:
            xs (List[Tensor]): A list of :class:`torch.Tensor` for a node
                type. The number of elements in the list equals the number
                of message types whose destination node type is the
                current node type.
        """
        if self.aggr == "concat":
            return torch.cat(xs, dim=-1)

        x = torch.stack(xs, dim=-1)
        if self.aggr == "add":
            return x.sum(dim=-1)
        elif self.aggr == "mean":
            return x.mean(dim=-1)
        elif self.aggr == "max":
            return x.max(dim=-1)[0]
        elif self.aggr == "mul":
            # `prod` returns a tensor directly, so unlike `max` no `[0]`
            # indexing is needed.
            return x.prod(dim=-1)
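# A minimal usage sketch for HeteroConv (not part of the original source).
# The node types ("paper", "author"), the message type tuple, and all
# shapes are invented for illustration.
def _example_hetero_conv():
    message_key = ("paper", "written_by", "author")
    convs = {
        message_key: HeteroSAGEConv(
            in_channels_neigh=16, out_channels=32, in_channels_self=8
        ),
    }
    hetero_conv = HeteroConv(convs, aggr="add")
    node_features = {
        "paper": torch.randn(4, 16),
        "author": torch.randn(3, 8),
    }
    edge_indices = {
        message_key: torch.tensor([[0, 1, 2, 3], [1, 2, 0, 0]]),
    }
    # Returns one embedding tensor per destination ("tail") node type.
    node_emb = hetero_conv(node_features, edge_indices)
    assert node_emb["author"].shape == (3, 32)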
def forward_op(x, module_dict, **kwargs):
    r"""A helper function for the heterogeneous operations. Given a
    dictionary input `x`, it returns a dictionary with the same keys, where
    each value is produced by applying the corresponding module in
    `module_dict` with the specified parameters. The keys in `x` must match
    the keys in the `module_dict`.

    Args:
        x (Dict[str, Tensor]): A dictionary in which the value of each item
            is a tensor.
        module_dict (:class:`torch.nn.ModuleDict`): Each value of the
            `module_dict` is fed with the corresponding value in `x`.
        **kwargs (optional): Parameters that will be passed into each value
            of the `module_dict`.
    """
    if not isinstance(x, dict):
        raise ValueError("The input x should be a dictionary.")
    res = {}
    for key in x:
        res[key] = module_dict[key](x[key], **kwargs)
    return res
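# A minimal usage sketch for forward_op (not part of the original source):
# it applies a per-node-type module to a dictionary of feature tensors.
# The node types and shapes are invented for illustration.
def _example_forward_op():
    module_dict = nn.ModuleDict({
        "paper": nn.Linear(16, 32),
        "author": nn.Linear(8, 32),
    })
    x = {
        "paper": torch.randn(4, 16),
        "author": torch.randn(3, 8),
    }
    res = forward_op(x, module_dict)
    assert res["paper"].shape == (4, 32)
    assert res["author"].shape == (3, 32)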
def loss_op(pred, y, index, loss_func):
    r"""A helper function for the heterogeneous loss operations. This
    function sums the loss over all node types.

    Args:
        pred (Dict[str, Tensor]): A dictionary of prediction results.
        y (Dict[str, Tensor]): A dictionary of labels. The keys should
            match the keys in `pred`.
        index (Dict[str, Tensor]): A dictionary of indices on which the
            loss will be computed. Each value should be a
            :class:`torch.LongTensor`. Note that `y` will not be indexed
            by `index`; it is assumed that `y` has already been split into
            the proper sets.
        loss_func (callable): The loss function to apply.
    """
    loss = 0
    for node_type in pred:
        idx = index[node_type]
        loss += loss_func(pred[node_type][idx], y[node_type])
    return loss
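# A minimal usage sketch for loss_op (not part of the original source).
# Note that only `pred` is indexed by `index`; `y` is assumed to already
# be split into the proper set. The node type, class count, and indices
# are invented for illustration.
def _example_loss_op():
    pred = {"author": torch.randn(3, 5, requires_grad=True)}  # logits
    index = {"author": torch.tensor([0, 2])}  # nodes in this split
    y = {"author": torch.tensor([1, 4])}      # labels, already split
    loss = loss_op(pred, y, index, nn.functional.cross_entropy)
    loss.backward()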