deepsnap.hetero_gnn

Heterogeneous GNN “Wrapper” Layer

class HeteroConv(convs, aggr='add', parallelize=False)[source]

Bases: torch.nn.modules.module.Module

A “wrapper” layer designed for heterogeneous graph layers. It takes a heterogeneous graph convolution layer, such as deepsnap.hetero_gnn.HeteroSAGEConv, at initialization. DeepSNAP does not currently support parallelize=True.

Note

For more detailed use of HeteroConv, see the examples/node_classification_hetero folder.

aggregate(xs: List[torch.Tensor])[source]

The aggregation for each node type. Currently supports the concat, add, mean, max and mul aggregations.

Parameters

xs (List[Tensor]) – A list of torch.Tensor for one node type. The number of elements in the list equals the number of message types whose destination node type is the current node type.
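The aggregation modes can be sketched in plain Python, with nested lists standing in for the torch.Tensor values that the real layer operates on (a conceptual sketch, not DeepSNAP's implementation):

```python
from math import prod

def aggregate(xs, aggr="add"):
    """Combine per-message-type outputs for one destination node type.

    xs: one feature matrix (list of rows) per message type; all matrices
    share the same shape. Lists stand in for torch.Tensor here.
    """
    if aggr == "concat":
        # Concatenate along the feature dimension, node by node.
        return [sum((x[i] for x in xs), []) for i in range(len(xs[0]))]
    ops = {"add": sum, "mean": lambda v: sum(v) / len(v),
           "max": max, "mul": prod}
    op = ops[aggr]
    # zip(*xs) pairs up each node's rows across message types;
    # zip(*rows) pairs up each feature across those rows.
    return [[op(v) for v in zip(*rows)] for rows in zip(*xs)]
```

For two message types producing [[1.0, 2.0]] and [[3.0, 4.0]] for a single destination node, "add" yields [[4.0, 6.0]] while "concat" yields [[1.0, 2.0, 3.0, 4.0]].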

forward(node_features, edge_indices, edge_features=None)[source]

The forward function for HeteroConv.

Parameters
  • node_features (Dict[str, Tensor]) – A dictionary in which each key is a node type and the corresponding value is a node feature tensor.

  • edge_indices (Dict[str, Tensor]) – A dictionary in which each key is a message type and the corresponding value is an edge index tensor.

  • edge_features (Dict[str, Tensor]) – A dictionary in which each key is an edge type and the corresponding value is an edge feature tensor. The default value is None.
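The dictionary conventions can be illustrated with hypothetical toy inputs. In DeepSNAP, node types are strings and message types are (src_type, edge_type, dst_type) tuples; the type names and values below are made up for illustration, and plain lists stand in for tensors:

```python
# Hypothetical toy inputs showing the key conventions HeteroConv.forward
# expects. Lists stand in for torch.Tensor values.
node_features = {
    "paper":  [[0.1, 0.2], [0.3, 0.4]],   # two "paper" nodes, two features
    "author": [[1.0, 1.0]],               # one "author" node
}
edge_indices = {
    # Edges from author 0 to papers 0 and 1, as (src_row, dst_row).
    ("author", "writes", "paper"): [[0, 0], [0, 1]],
}

# Conceptually, forward() runs each message type's conv, then groups the
# per-message-type outputs by destination node type before aggregate():
dst_groups = {}
for message_type in edge_indices:
    _, _, dst = message_type
    dst_groups.setdefault(dst, []).append(message_type)
```

Here dst_groups maps "paper" to its single incoming message type; with several message types ending in the same destination type, aggregate() combines their outputs.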

reset_parameters()[source]

Heterogeneous GNN Layers

class HeteroSAGEConv(in_channels_neigh, out_channels, in_channels_self=None, remove_self_loop=True)[source]

Bases: torch_geometric.nn.conv.message_passing.MessagePassing

The heterogeneous-compatible GraphSAGE operator, derived from the “Inductive Representation Learning on Large Graphs”, “Modeling polypharmacy side effects with graph convolutional networks”, and “Modeling Relational Data with Graph Convolutional Networks” papers.

Note

This layer is usually used with the HeteroConv.

Parameters
  • in_channels_neigh (int) – The input dimension of the neighbor node type.

  • out_channels (int) – The dimension of the output.

  • in_channels_self (int) – The input dimension of the self node type. Default is None, in which case in_channels_self is set equal to in_channels_neigh.

  • remove_self_loop (bool) – Whether to remove self loops using torch_geometric.utils.remove_self_loops. Default is True.
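The per-message-type computation follows the GraphSAGE pattern: aggregate neighbor features along the edges, then combine them with the destination node's own features. The sketch below uses mean aggregation followed by concatenation as a simplifying assumption and omits the learned linear transformations; it is not DeepSNAP's exact implementation, and plain lists stand in for tensors:

```python
def sage_sketch(x_neigh, x_self, edge_index):
    """Simplified GraphSAGE-style update: mean-aggregate neighbor
    features per destination node, then concatenate with the node's
    own features. The real layer also applies learned linear maps."""
    n_dst, n_feat = len(x_self), len(x_neigh[0])
    sums = [[0.0] * n_feat for _ in range(n_dst)]
    counts = [0] * n_dst
    src_nodes, dst_nodes = edge_index
    for s, d in zip(src_nodes, dst_nodes):
        counts[d] += 1
        for k in range(n_feat):
            sums[d][k] += x_neigh[s][k]
    # Mean over neighbors (zero vector for isolated nodes).
    agg = [[v / c if c else 0.0 for v in row]
           for row, c in zip(sums, counts)]
    return [xs + xa for xs, xa in zip(x_self, agg)]
```

For a single destination node with self feature [1.0] receiving from neighbors with features [2.0] and [4.0], the result is [1.0, 3.0]: the self feature followed by the neighbor mean.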

forward(node_feature_neigh, node_feature_self, edge_index, edge_weight=None, size=None, res_n_id=None)[source]
message(node_feature_neigh_j, node_feature_self_i, edge_weight)[source]
message_and_aggregate(edge_index, node_feature_neigh)[source]

Fuses message() and aggregate() into a single function, which saves memory by avoiding explicit message materialization. For more information, please refer to the PyTorch Geometric documentation.

Parameters
  • edge_index (torch_sparse.SparseTensor) – The edge_index sparse tensor.

  • node_feature_neigh (torch.Tensor) – Neighbor feature tensor.
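The fused step is equivalent to a sparse matrix–dense matrix product out = A @ X, where A is the (dst × src) adjacency matrix and X the neighbor feature matrix. A plain-Python sketch of that product on a COO edge list (the real implementation uses torch_sparse.SparseTensor and operates on tensors):

```python
def spmm_sketch(edge_index, node_feature_neigh, n_dst):
    """Sum-aggregate neighbor features via a sparse-dense product,
    out = A @ X, with the adjacency given as (src, dst) COO lists.
    This is what message_and_aggregate fuses: no per-edge message
    tensor is ever materialized."""
    n_feat = len(node_feature_neigh[0])
    out = [[0.0] * n_feat for _ in range(n_dst)]
    src, dst = edge_index
    for s, d in zip(src, dst):
        for k in range(n_feat):
            out[d][k] += node_feature_neigh[s][k]
    return out
```

With edges (0→0) and (1→0) and neighbor features [[1.0], [2.0]], the single destination node receives the sum [3.0].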

update(aggr_out, node_feature_self, res_n_id)[source]

Heterogeneous GNN Functions

forward_op(x, module_dict, **kwargs)[source]

A helper function for heterogeneous operations. Given a dictionary input x, it returns a dictionary with the same keys, where each value is the result of applying the corresponding module in module_dict (with the specified parameters) to the value in x. The keys in x are the same as the keys in module_dict.

Parameters
  • x (Dict[str, Tensor]) – A dictionary that the value of each item is a tensor.

  • module_dict (torch.nn.ModuleDict) – A module dictionary; the module under each key is applied to the value in x with the matching key.

  • **kwargs (optional) – Parameters that will be passed into each value of the module_dict.
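The behavior amounts to a keyed dictionary map, which can be sketched with plain callables in place of the torch.nn.ModuleDict (a sketch of the semantics, not DeepSNAP's code):

```python
def forward_op(x, module_dict, **kwargs):
    """Apply module_dict[key] to x[key] for every key in x, returning a
    dictionary with the same keys. In DeepSNAP, module_dict is a
    torch.nn.ModuleDict and the values of x are tensors; here any
    callables and values work."""
    return {key: module_dict[key](x[key], **kwargs) for key in x}

# Toy usage with plain callables standing in for modules:
features = {"paper": [1, 2], "author": [3]}
modules = {"paper": len, "author": len}
counts = forward_op(features, modules)
```

Here counts maps each node type to the length of its feature list, illustrating that the output keys mirror the input keys.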

loss_op(pred, y, index, loss_func)[source]

A helper function for heterogeneous loss operations. It sums the losses over all node types.

Parameters
  • pred (Dict[str, Tensor]) – A dictionary of prediction results.

  • y (Dict[str, Tensor]) – A dictionary of labels. The keys should match the keys in pred.

  • index (Dict[str, Tensor]) – A dictionary of indices on which the loss will be computed. Each value should be a torch.LongTensor. Notice that y will not be indexed by index; here we assume y has already been split into the proper sets.

  • loss_func (callable) – The defined loss function.
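The indexing convention (pred is indexed by index, y is not) can be sketched in plain Python, with lists in place of tensors and a made-up squared-error loss standing in for a real loss function:

```python
def loss_op(pred, y, index, loss_func):
    """Sum the loss over all node types. pred[node_type] is restricted
    to index[node_type]; y[node_type] is assumed to already contain
    only the labels for that split. Plain-Python sketch of the
    semantics, not DeepSNAP's implementation."""
    total = 0.0
    for node_type in pred:
        selected = [pred[node_type][i] for i in index[node_type]]
        total += loss_func(selected, y[node_type])
    return total

def sse(p, t):
    """Toy loss: sum of squared errors (illustration only)."""
    return sum((a - b) ** 2 for a, b in zip(p, t))
```

For predictions [1.0, 2.0, 3.0] on "paper" nodes with index [0, 2] and labels [0.0, 4.0], the selected predictions are [1.0, 3.0] and the summed squared error is 2.0.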