import copy
import math
import types
import random
import networkx as nx
import numpy as np
import torch
import deepsnap
from deepsnap.graph import Graph
from deepsnap.hetero_graph import HeteroGraph
import pdb
from typing import (
Dict,
List,
Union
)
import warnings
class Generator(object):
    r"""
    Abstract class of on the fly generator used in the dataset.
    It generates on the fly graphs, which will be fed into the model.

    Args:
        sizes (list): Candidate graph sizes that :meth:`_get_size` samples
            from. May be ``None`` (e.g. for :class:`EnsembleGenerator`,
            which does not sample sizes itself).
        size_prob (array like, optional): Sampling probability for each
            entry of ``sizes``. Defaults to the uniform distribution.
        dataset_len (int): Virtual length reported by :meth:`__len__`.
    """
    def __init__(self, sizes, size_prob=None, dataset_len=0):
        self.sizes = sizes
        if sizes is not None:
            if size_prob is None:
                # uniform distribution over the candidate sizes
                self.size_prob = np.ones(len(sizes)) / len(sizes)
            else:
                self.size_prob = size_prob
        # by default length of generator dataset is 0
        self._len = dataset_len

    def __len__(self):
        return self._len

    def set_len(self, dataset_len):
        # update the virtual dataset length
        self._len = dataset_len

    def _get_size(self, size=None):
        # sample a size according to self.size_prob unless one is given
        if size is None:
            return np.random.choice(
                self.sizes, size=1, replace=True, p=self.size_prob
            )[0]
        else:
            return size

    @property
    def num_node_labels(self):
        # overwrite in subclass
        return 0

    @property
    def num_nodes(self):
        # overwrite in subclass
        return 0

    @property
    def num_edge_labels(self):
        # overwrite in subclass
        return 0

    @property
    def num_edges(self):
        # overwrite in subclass
        return 0

    @property
    def num_graph_labels(self):
        # overwrite in subclass
        return 0

    def generate(self):
        r"""
        Overwrite in subclass. Generates and returns a
        :class:`deepsnap.graph.Graph` object

        Returns:
            :class:`deepsnap.graph.Graph`: A DeepSNAP graph object.
        """
        return Graph(nx.Graph())
class EnsembleGenerator(Generator):
    r"""
    A generator that is an ensemble of many generators.
    """
    def __init__(self, generators, gen_prob=None, dataset_len=0):
        r"""
        A generator that is an ensemble of many generators.

        Args:
            generators (List[:class:`Generator`]): A list of Generators.
            gen_prob (Array like): An array like (list) probabilities with
                the same length as generators. It specifies the probability
                of sampling from each generator. If it is `None`, the
                :class:`EnsembleGenerator` will uniformly sample a generator.
            dataset_len (int): Virtual length of the generated dataset.
        """
        super(EnsembleGenerator, self).__init__(None, dataset_len=dataset_len)
        if gen_prob is None:
            # uniform sampling over the member generators
            self.gen_prob = np.ones(len(generators)) / len(generators)
        else:
            self.gen_prob = gen_prob
        self.generators = generators

    @property
    def num_node_labels(self):
        r"""
        Returns number of the node labels in the generated graphs.

        Returns:
            int: The number of node labels.
        """
        # TODO: change to unique as what we did in graph.py
        return max([gen.num_node_labels for gen in self.generators])

    @property
    def num_nodes(self):
        r"""
        Returns number of the nodes in each generated graphs.

        Returns:
            list: List of the number of nodes.
        """
        return [gen.num_nodes for gen in self.generators]

    @property
    def num_edge_labels(self):
        r"""
        Returns number of the edge labels in the generated graphs.

        Returns:
            int: The number of edge labels.
        """
        # TODO: change to unique as what we did in graph.py
        return max([gen.num_edge_labels for gen in self.generators])

    @property
    def num_edges(self):
        r"""
        Returns number of the edges in each generated graphs.

        Returns:
            list: List of the number of edges.
        """
        return [gen.num_edges for gen in self.generators]

    @property
    def num_graph_labels(self):
        r"""
        Returns number of the graph labels in the generated graphs.

        Returns:
            int: The number of graph labels.
        """
        # TODO: change to unique as what we did in graph.py
        return max([gen.num_graph_labels for gen in self.generators])

    def generate(self, **kwargs):
        r"""
        Generate a graph from one randomly sampled member generator.

        Returns:
            :class:`deepsnap.graph.Graph`: The graph produced by the
            sampled generator (a single graph, not a list).
        """
        # sample one member generator according to gen_prob, then delegate
        gen = np.random.choice(self.generators, 1, p=self.gen_prob)[0]
        return gen.generate(**kwargs)
class GraphDataset(object):
    r"""
    A plain python object modeling a list of :class:`deepsnap.graph.Graph`
    objects with various (optional) attributes.

    Args:
        graphs (list, optional): A list of :class:`deepsnap.graph.Graph`.
        task (str): The task that this :class:`GraphDataset` is used for
            (task = `node` or `edge` or `link_pred` or `graph`).
        custom_split_graphs (list): A list of 2 (train and val)
            or 3 (train, val and test) lists of splitted graphs, used in
            custom split of the `graph` task.
        edge_negative_sampling_ratio (float): The number of negative
            samples compared to that of positive edges. Default value
            is 1.
        edge_message_ratio (float): The number of message passing edges
            compared to that of training supervision edges. Default value
            is 0.8.
        edge_train_mode (str): Use `all` or `disjoint`. In `all`
            mode, training supervision edges are same with the message
            passing edges. In `disjoint` mode, training supervision
            objectives are different from the message passing edges.
            The difference between these two modes please see
            the `DeepSNAP link prediction Colab <https://colab.research.
            google.com/drive/1ycdlJuse7l2De7wi51lFd_nCuaWgVABc?
            usp=sharing>`_.
        edge_split_mode (str): Use `exact` or `approximate`. This mode is
            designed for the heterogeneous graph. If the mode is `exact`,
            split the heterogeneous graph according to both the ratio
            and the split type. If the mode is `approximate`, split the
            heterogeneous graph regardless of the split type.
        minimum_node_per_graph (int): If the number of nodes of a graph
            is smaller than the minimum node per graph, that graph will
            be filtered out.
        generator (:class:`Generator`): The dataset will be on-the-fly
            generated. The on-the-fly generator will be used, if the
            :obj:`self.graphs` is empty or `None`, and the generator
            (:class:`Generator`) is provided with an overwritten
            :meth:`generate` method.
        resample_negatives (bool): Whether to resample negative edges in
            each iteration of the `link_pred` task. User needs to set this
            variable in the case of tensor backend for the custom split.
        resample_disjoint (bool): Whether to resample disjoint training
            edges in the `disjonint` `link_pred` task.
        resample_disjoint_period (int): The number of iterations after
            which the training edges in the `disjoint` mode are resampled.
        negative_label_val (int, optional): The value of negative edges
            generated in link_pred task. User needs to set this variable
            in the case of tensor backend custom split.
        netlib (types.ModuleType, optional): The graph backend module.
            Currently DeepSNAP supports the NetworkX and SnapX (for
            SnapX only the undirected homogeneous graph) as the graph
            backend. Default graph backend is the NetworkX.
    """
    def __init__(
        self,
        graphs: List[Graph] = None,
        task: str = "node",
        custom_split_graphs: List[Graph] = None,
        edge_negative_sampling_ratio: float = 1,
        edge_message_ratio: float = 0.8,
        edge_train_mode: str = "all",
        edge_split_mode: str = "exact",
        minimum_node_per_graph: int = 5,
        generator=None,
        resample_negatives: bool = False,
        resample_disjoint: bool = False,
        resample_disjoint_period: int = 1,
        negative_label_val: int = None,
        netlib=None
    ):
        # swap the global graph backend (NetworkX / SnapX) if requested
        if netlib is not None:
            deepsnap._netlib = netlib
        if graphs is not None:
            # make sure graphs is a list
            if not isinstance(graphs, list):
                graphs = [graphs]
            # support user' input a list of netlib.Graph instead of Graph
            for i, graph in enumerate(graphs):
                if not isinstance(graph, Graph):
                    graphs[i] = Graph(graph, netlib=netlib)
        # validity check for `task`
        if task not in ["node", "edge", "link_pred", "graph"]:
            raise ValueError(
                "task must be one of node, edge, link_pred or graph."
            )
        # validity check for `edge_train_mode`
        if edge_train_mode not in ["all", "disjoint"]:
            raise ValueError("edge_train_mode must be all or disjoint.")
        # validity check for `edge_split_mode`
        if edge_split_mode not in ["exact", "approximate"]:
            raise ValueError(
                "edge_split_mode must be exact or approximate."
            )
        # validity check for `resample_negatives`
        if resample_negatives and (task != "link_pred"):
            raise ValueError(
                "resample_negatives set to True only make sense "
                "when self.task is link_pred."
            )
        # validity check for `resample_disjoint`
        # if resample_negatives and (self.task != "link_pred"):
        if (
            resample_disjoint
            and (
                (task != "link_pred")
                or (edge_train_mode != "disjoint")
            )
        ):
            raise ValueError(
                "resample_disjoint set to True only make sense "
                "when self.task is `link_pred` and edge_train_mode is "
                "disjoint."
            )
        # validity check for `resample_negatives`
        if (negative_label_val is not None) and (task != "link_pred"):
            raise ValueError(
                "negative_label_val is set only make sense "
                "when self.task is link_pred."
            )
        # parameter initialization
        self.graphs = graphs
        self.task = task
        self.custom_split_graphs = custom_split_graphs
        self.edge_message_ratio = edge_message_ratio
        self.edge_negative_sampling_ratio = edge_negative_sampling_ratio
        self.edge_train_mode = edge_train_mode
        self.edge_split_mode = edge_split_mode
        self.minimum_node_per_graph = minimum_node_per_graph
        self.resample_negatives = resample_negatives
        self.resample_disjoint = resample_disjoint
        self.resample_disjoint_period = resample_disjoint_period
        self.negative_label_val = negative_label_val
        # set private parameters
        self._split_types = None
        self._is_tensor = False
        # graphs preprocessing
        if graphs is None or len(graphs) == 0:
            if generator is None:
                raise ValueError("Graphs are None")
            else:
                # on-the-fly dataset
                self.generator = generator
                self.graphs = None
                self.otf_device = None
        elif generator is not None:
            raise ValueError(
                "Both graphs and on-the-fly generator are "
                "provided (only one should be provided."
            )
        else:
            # by default on-the-fly generator is not used
            # when generator is not provide
            self.generator = None
            # filter graphs that are too small
            if self.minimum_node_per_graph > 0:
                graphs_filter = []
                for idx, graph in enumerate(self.graphs):
                    if isinstance(graph, Graph):
                        if isinstance(graph, HeteroGraph):
                            # heterogeneous: num_nodes() is a per-type dict
                            if (
                                sum(graph.num_nodes().values())
                                >= self.minimum_node_per_graph
                            ):
                                graphs_filter.append(graph)
                            else:
                                warnings.warn(
                                    f"the {idx}-th graph in self.graphs is "
                                    "filtered out as it contains "
                                    f"{sum(graph.num_nodes().values())} nodes,"
                                    " which is less than "
                                    "self.minimum_node_per_graph: "
                                    f"{self.minimum_node_per_graph}."
                                )
                        else:
                            if graph.num_nodes >= self.minimum_node_per_graph:
                                graphs_filter.append(graph)
                            else:
                                warnings.warn(
                                    f"the {idx}-th graph in self.graphs is "
                                    "filtered out as it contains "
                                    f"{graph.num_nodes} nodes,"
                                    " which is less than "
                                    "self.minimum_node_per_graph: "
                                    f"{self.minimum_node_per_graph}."
                                )
                    else:
                        raise TypeError(
                            "element in self.graphs of unexpected type"
                        )
                self.graphs = graphs_filter
            # update graph in self.graphs with appropriate custom
            # split component
            for graph in self.graphs:
                if not hasattr(graph, "_custom_update_flag"):
                    # assign task to graph
                    graph.task = self.task
                    # custom support
                    if isinstance(graph, Graph):
                        if isinstance(graph, HeteroGraph):
                            # identity node mapping over all node types
                            mapping = {
                                x: x
                                for x in range(sum(graph.num_nodes().values()))
                            }
                        else:
                            mapping = {x: x for x in range(graph.num_nodes)}
                    else:
                        raise TypeError(
                            "element in self.graphs of unexpected type"
                        )
                    graph._custom_update(mapping)
            # update graph in self.graphs with negative_label_val
            if self.task == "link_pred":
                if self.negative_label_val is None:
                    # derive the negative label value as (max positive
                    # edge label over all graphs) + 1
                    negative_label_val = 0
                    for graph in self.graphs:
                        if (
                            hasattr(graph, "edge_label")
                            and (graph.edge_label is not None)
                            and (self.task == "link_pred")
                        ):
                            if isinstance(graph, Graph):
                                if isinstance(graph, HeteroGraph):
                                    for message_type in graph.edge_label:
                                        negative_label_val = max(
                                            negative_label_val,
                                            torch.max(
                                                graph.edge_label[message_type]
                                            ).item() + 1
                                        )
                                else:
                                    negative_label_val = max(
                                        negative_label_val,
                                        torch.max(graph.edge_label).item() + 1
                                    )
                            else:
                                raise TypeError(
                                    "element in self.graphs of unexpected type"
                                )
                    self.negative_label_val = negative_label_val
                # NOTE(review): this uses the local `negative_label_val`
                # (the parameter, or the derived value above) — confirm it
                # stays consistent with self.negative_label_val.
                for graph in self.graphs:
                    graph.negative_label_val = (
                        copy.deepcopy(negative_label_val)
                    )
            self._update_tensor_negative_edges()
            self._custom_mode_update()
        self._reset_cache()
    def _update_tensor_negative_edges(self):
        r"""
        Create negative edges and labels for tensor backend link_pred case.

        Backs up each graph's ``edge_label`` / ``edge_label_index`` and then
        appends negative samples — user-provided (``negative_edge``) when
        every graph has them, otherwise randomly sampled.
        """
        # negative sampling only applies to link prediction
        if self.task != "link_pred":
            return
        # only the pure tensor backend (every graph has G is None) is handled
        if not all([graph.G is None for graph in self.graphs]):
            return
        # custom negative edges must be supplied for all graphs or for none
        any_negative_edges = any(
            ["negative_edge" in graph.keys for graph in self.graphs]
        )
        all_negative_edges = all(
            ["negative_edge" in graph.keys for graph in self.graphs]
        )
        if (not all_negative_edges) and any_negative_edges:
            raise ValueError(
                "either all graphs have negative edges or no graphs have "
                "negative edges."
            )
        else:
            self._is_tensor = True
            for graph in self.graphs:
                # keep originals so `split` can restore them later
                graph._edge_label = copy.deepcopy(graph.edge_label)
                graph._edge_label_index = copy.deepcopy(graph.edge_label_index)
                if all_negative_edges:
                    # user-provided negative edges
                    graph._custom_create_neg_sampling(
                        self.edge_negative_sampling_ratio, resample=False
                    )
                else:
                    # randomly sampled negative edges
                    graph._create_neg_sampling(
                        self.edge_negative_sampling_ratio, resample=False
                    )
def __len__(self) -> int:
r"""
Returns:
int: The number of graph in graphs.
"""
if self.graphs is None:
return len(self.generator)
else:
return len(self.graphs)
    @property
    def num_node_features(self) -> int:
        r"""
        Returns the node feature dimension.

        Returns:
            int: The node feature dimension for the graphs
            in the dataset.
        """
        # delegates to the cached example graph (see _reset_cache)
        return self._graph_example.num_node_features
@property
def num_node_labels(self) -> int:
r"""
Returns the number of node labels.
Returns:
int: The number of node labels for the graphs
in the dataset.
"""
if self._num_node_labels is None:
if self.graphs is None:
self._num_node_labels = self.generator.num_node_labels
else:
unique_node_labels = torch.LongTensor([])
for graph in self.graphs:
unique_node_labels = torch.cat([
unique_node_labels, graph.get_num_labels("node_label")
])
self._num_node_labels = torch.unique(
unique_node_labels
).shape[0]
return self._num_node_labels
@property
def num_nodes(self) -> List[int]:
r"""
Return the number of nodes for the graphs in the dataset.
Returns:
list: A list of number of nodes for the graphs
in the dataset.
"""
if self._num_nodes is None:
if self.graphs is None:
self._num_nodes = self.generator.num_nodes
else:
self._num_nodes = (
[graph.num_nodes for graph in self.graphs]
)
return self._num_nodes
    @property
    def num_edge_features(self) -> int:
        r"""
        Returns the edge feature dimension.

        Returns:
            int: The edge feature dimension for the graphs
            in the dataset.
        """
        # delegates to the cached example graph (see _reset_cache)
        return self._graph_example.num_edge_features
@property
def num_edge_labels(self) -> int:
r"""
Returns the number of edge labels.
Returns:
int: The number of edge labels for the graphs
in the dataset.
"""
if self._num_edge_labels is None:
if self.graphs is None:
self._num_edge_labels = self.generator.num_edge_labels
else:
unique_edge_labels = torch.LongTensor([])
for graph in self.graphs:
unique_edge_labels = torch.cat([
unique_edge_labels, graph.get_num_labels("edge_label")
])
self._num_edge_labels = torch.unique(
unique_edge_labels
).shape[0]
return self._num_edge_labels
@property
def num_edges(self) -> List[int]:
r"""
Return the number of edges for the graphs in the dataset.
Returns:
list: A list of number of edges for the graphs
in the dataset.
"""
if self._num_edges is None:
if self.graphs is None:
self._num_edges = self.generator.num_edges
else:
self._num_edges = (
[graph.num_edges for graph in self.graphs]
)
return self._num_edges
    @property
    def num_graph_features(self) -> int:
        r"""
        Returns the graph feature dimension.

        Returns:
            int: The graph feature dimension for the graphs
            in the dataset.
        """
        # delegates to the cached example graph (see _reset_cache)
        return self._graph_example.num_graph_features
@property
def num_graph_labels(self) -> int:
r"""
Returns the number of graph labels.
Returns:
int: The number of graph labels for the graphs
in the dataset.
"""
if self._num_graph_labels is None:
if self.graphs is None:
self._num_graph_labels = self.generator.num_graph_labels
else:
unique_graph_labels = torch.LongTensor([])
for graph in self.graphs:
unique_graph_labels = torch.cat([
unique_graph_labels,
graph.get_num_labels("graph_label")
])
self._num_graph_labels = torch.unique(
unique_graph_labels
).shape[0]
return self._num_graph_labels
@property
def num_labels(self) -> int:
r"""
A General wrapper that returns the number of labels depending on
the task.
Returns:
int: The number of labels, depending on the task.
"""
if self.task == "node":
return self.num_node_labels
elif self.task == "edge" or self.task == "link_pred":
return self.num_edge_labels
elif self.task == "graph":
return self.num_graph_labels
else:
raise ValueError(f"Task {self.task} not supported")
[docs] def num_dims_dict(self) -> Dict[str, int]:
r"""
Dimensions of all fields.
Returns:
dict: Dimensions of all fields. For example, if
graphs have two attributes the `node_feature`
and the `graph_label`. The returned dictionary will
have two keys, `node_feature` and `graph_label`, and
two values, node feature dimension and graph label
dimension.
"""
dim_dict = {}
for key in self._graph_example.keys:
tensor = self._graph_example[key]
if not torch.is_tensor(tensor):
continue
if tensor.ndim == 1:
dim_dict[key] = 1
elif tensor.ndim == 2:
dim_dict[key] = tensor.size()[-1]
else:
raise ValueError(f"Dimension of tensor {key} exceeds 2.")
return dim_dict
    def _custom_mode_update(self):
        r"""
        Update self.general_splits_mode, self.disjoint_split_mode &
        self.negative_edges_mode to indicate whether we are working on custom
        support for:

        (1) general transductive or inductive custom split
        or (2) disjoint train custom split in disjoint link_pred task
        or (3) custom negative edges in link_pred task
        """
        custom_keys = ["general_splits", "disjoint_split", "negative_edges"]
        # default every mode to "random" (via __setitem__, i.e. setattr)
        for custom_key in custom_keys:
            self[f"{custom_key}_mode"] = "random"
        # make sure custom appeared in all graphs or no custom appeared
        # in any graphs
        custom_in_graphs = all(
            (graph.custom is not None) for graph in self.graphs
        )
        custom_not_in_graphs = all(
            (graph.custom is None) for graph in self.graphs
        )
        if not custom_in_graphs and not custom_not_in_graphs:
            raise ValueError(
                "custom needs to be in all graphs or not in any graphs"
            )
        if custom_in_graphs:
            # each custom key must likewise appear in all graphs or in none
            for custom_key in custom_keys:
                custom_key_in_custom = all(
                    (custom_key in graph.custom)
                    for graph in self.graphs
                )
                custom_key_not_in_custom = all(
                    (custom_key not in graph.custom) for graph in self.graphs
                )
                if not custom_key_in_custom and not custom_key_not_in_custom:
                    raise ValueError(
                        f"{custom_key} needs to be in all `graph.custom`s or "
                        "not in any `graph.custom`s"
                    )
                if custom_key_in_custom:
                    self[f"{custom_key}_mode"] = "custom"
        # custom inductive splits
        if self.custom_split_graphs is not None:
            self.general_splits_mode = "custom"
    def _split_transductive(
        self,
        split_ratio: List[float],
        split_types: List[str] = None,
        shuffle: bool = True
    ) -> List[Graph]:
        r"""
        Split the dataset assuming training process is transductive.

        Args:
            split_ratio: number of data splitted into train, validation
                (and test) set.
            split_types: Types to split on (heterogeneous graphs only).
            shuffle: Whether to shuffle before splitting each graph.

        Returns:
            list: A list of 3 (2) lists of :class:`deepsnap.graph.Graph`
            object corresponding to train, validation (and test) set.
        """
        # split each graph individually; each entry of split_graphs is
        # [train_graph, val_graph(, test_graph)] for one source graph
        split_graphs = []
        for graph in self.graphs:
            if self.general_splits_mode == "custom":
                split_graph = graph._custom_split(
                    task=self.task
                )
            elif self.general_splits_mode == "random":
                if isinstance(graph, Graph):
                    if isinstance(graph, HeteroGraph):
                        split_graph = graph.split(
                            task=self.task,
                            split_types=split_types,
                            split_ratio=split_ratio,
                            edge_split_mode=self.edge_split_mode,
                            shuffle=shuffle
                        )
                    else:
                        split_graph = graph.split(
                            self.task, split_ratio, shuffle=shuffle
                        )
                else:
                    raise TypeError(
                        "element in self.graphs of unexpected type"
                    )
            split_graphs.append(split_graph)
        # transpose: split_graphs[i] now holds all graphs of split i
        split_graphs = list(map(list, zip(*split_graphs)))
        # for disjoint link_pred, re-split each training graph into
        # message-passing and supervision edges
        if self.disjoint_split_mode == "custom":
            # resample_disjoint when in disjoint split custom mode
            # would override the custom disjoint split edges
            self.resample_disjoint = False
            for i, graph in enumerate(split_graphs[0]):
                if (
                    self.task == "link_pred"
                    and self.edge_train_mode == "disjoint"
                ):
                    graph = graph._custom_split_link_pred_disjoint()
                    split_graphs[0][i] = graph
        elif self.disjoint_split_mode == "random":
            for i, graph in enumerate(split_graphs[0]):
                if (
                    self.task == "link_pred"
                    and self.edge_train_mode == "disjoint"
                ):
                    if isinstance(graph, Graph):
                        # store the original edge_label
                        graph_edge_label = None
                        if (
                            self.resample_disjoint
                            and hasattr(graph, "edge_label")
                        ):
                            graph_edge_label = graph.edge_label
                        # [1] of split_link_pred is the supervision graph
                        if isinstance(graph, HeteroGraph):
                            graph = graph.split_link_pred(
                                split_types=split_types,
                                split_ratio=self.edge_message_ratio,
                                edge_split_mode=self.edge_split_mode
                            )[1]
                        else:
                            graph = graph.split_link_pred(
                                self.edge_message_ratio
                            )[1]
                        graph._is_train = True
                        split_graphs[0][i] = graph
                        # save the original edge_label
                        if graph_edge_label is not None:
                            graph._edge_label = copy.deepcopy(graph_edge_label)
                        else:
                            graph._edge_label = None
                    else:
                        raise TypeError(
                            "element in self.graphs of unexpected type"
                        )
        # list of num_splits datasets
        # (e.g. [train dataset, val dataset, test dataset])
        dataset_return = []
        if self.negative_edges_mode == "random":
            for x in split_graphs:
                dataset_current = copy.copy(self)
                dataset_current.graphs = x
                if self.task == "link_pred":
                    for graph_temp in dataset_current.graphs:
                        if isinstance(graph_temp, Graph):
                            if isinstance(graph_temp, HeteroGraph):
                                graph_temp._create_neg_sampling(
                                    negative_sampling_ratio=(
                                        self.edge_negative_sampling_ratio
                                    ),
                                    split_types=split_types
                                )
                            else:
                                graph_temp._create_neg_sampling(
                                    self.edge_negative_sampling_ratio
                                )
                        else:
                            raise TypeError(
                                "element in self.graphs of unexpected type"
                            )
                dataset_return.append(dataset_current)
        elif self.negative_edges_mode == "custom":
            for i, x in enumerate(split_graphs):
                dataset_current = copy.copy(self)
                dataset_current.graphs = x
                if self.task == "link_pred":
                    for j, graph_temp in enumerate(dataset_current.graphs):
                        if isinstance(graph_temp, Graph):
                            # select the i-th set of user-provided negatives
                            graph_temp.negative_edge = (
                                graph_temp.negative_edges[i]
                            )
                            if isinstance(graph_temp, HeteroGraph):
                                graph_temp._custom_create_neg_sampling(
                                    self.edge_negative_sampling_ratio,
                                    split_types=split_types
                                )
                            else:
                                graph_temp._custom_create_neg_sampling(
                                    self.edge_negative_sampling_ratio
                                )
                        else:
                            raise TypeError(
                                "element in self.graphs of unexpected type"
                            )
                dataset_return.append(dataset_current)
        # resample negatives for train split (only for link prediction)
        dataset_return[0].resample_negatives = True
        return dataset_return
    def _split_inductive(
        self,
        split_ratio: List[float],
        split_types: List[str] = None,
        shuffle: bool = True
    ) -> List[Graph]:
        r"""
        Split the dataset assuming training process is inductive.

        Args:
            split_ratio: number of data splitted into train, validation
                (and test) set.
            split_types: Types to split on (heterogeneous graphs only).
            shuffle: Whether to shuffle the graphs before splitting.

        Returns:
            List[Graph]: a list of 3 (2) lists of graph object corresponding
            to train, validation (and test) set.
        """
        if self.general_splits_mode == "custom":
            split_graphs = self.custom_split_graphs
        elif self.general_splits_mode == "random":
            num_graphs = len(self.graphs)
            if num_graphs < len(split_ratio):
                raise ValueError(
                    "in _split_inductive num of graphs are smaller than the "
                    "number of splitted parts"
                )
            if shuffle:
                self._shuffle()
            # a list of num_splits list of graphs
            # (e.g. [train graphs, val graphs, test graphs])
            split_graphs = []
            # If the `default split` policy would result in empty splited
            # graphs, `secure split` policy would be used instead
            split_empty_flag = False
            split_offset = 0
            # perform `default split`
            for i, split_ratio_i in enumerate(split_ratio):
                if i != len(split_ratio) - 1:
                    num_split_i = int(split_ratio_i * num_graphs)
                    graphs_split_i = (
                        self.graphs[split_offset:split_offset + num_split_i]
                    )
                    split_offset += num_split_i
                else:
                    # the last split takes the remainder
                    graphs_split_i = self.graphs[split_offset:]
                if len(graphs_split_i) == 0:
                    split_empty_flag = True
                    split_offset = 0
                    split_graphs = []
                    break
                split_graphs.append(graphs_split_i)
            if split_empty_flag:
                # perform `secure split` s.t. guarantees all splitted graph
                # list contains at least one graph.
                for i, split_ratio_i in enumerate(split_ratio):
                    if i != len(split_ratio) - 1:
                        # reserve one graph per split, distribute the rest
                        num_split_i = (
                            1 +
                            int(
                                split_ratio_i
                                * (num_graphs - len(split_ratio))
                            )
                        )
                        graphs_split_i = (
                            self.graphs[
                                split_offset:split_offset + num_split_i
                            ]
                        )
                        split_offset += num_split_i
                    else:
                        graphs_split_i = self.graphs[split_offset:]
                    split_graphs.append(graphs_split_i)
        # create objectives for link_pred task
        if self.task == "link_pred":
            # if disjoint, this will split all graph's edges into 2:
            # message passing and objective edges
            # which is returned by the [1] of the split graphs
            if self.edge_train_mode == "disjoint":
                split_start = 0
            # in all mode, train graph has all edges used for both
            # message passing and objective.
            elif self.edge_train_mode == "all":
                split_start = 1
            for i in range(split_start, len(split_graphs)):
                for j in range(len(split_graphs[i])):
                    graph_temp = split_graphs[i][j]
                    if isinstance(graph_temp, Graph):
                        # store the original edge_label
                        graph_edge_label = None
                        if (
                            self.resample_disjoint
                            and (i == 0)
                            and hasattr(graph_temp, "edge_label")
                        ):
                            graph_edge_label = graph_temp.edge_label
                        if isinstance(graph_temp, HeteroGraph):
                            graph_temp = (
                                graph_temp.split_link_pred(
                                    split_types,
                                    self.edge_message_ratio,
                                    self.edge_split_mode
                                )[1]
                            )
                        else:
                            graph_temp = (
                                graph_temp.split_link_pred(
                                    self.edge_message_ratio
                                )[1]
                            )
                        # save the original edge_label
                        if graph_edge_label is not None:
                            graph_temp._edge_label = (
                                copy.deepcopy(graph_edge_label)
                            )
                        else:
                            graph_temp._edge_label = None
                        # set is_train flag
                        if i == 0:
                            graph_temp._is_train = True
                        split_graphs[i][j] = graph_temp
                    else:
                        raise TypeError(
                            "element in self.graphs of unexpected type."
                        )
        # list of num_splits datasets
        dataset_return = []
        for graphs in split_graphs:
            dataset_current = copy.copy(self)
            dataset_current.graphs = graphs
            if self.task == "link_pred":
                for graph_temp in dataset_current.graphs:
                    if isinstance(graph_temp, Graph):
                        if isinstance(graph_temp, HeteroGraph):
                            graph_temp._create_neg_sampling(
                                negative_sampling_ratio=(
                                    self.edge_negative_sampling_ratio
                                ),
                                split_types=split_types
                            )
                        else:
                            graph_temp._create_neg_sampling(
                                self.edge_negative_sampling_ratio
                            )
                    else:
                        raise TypeError(
                            "element in self.graphs of unexpected type"
                        )
            dataset_return.append(dataset_current)
        # resample negatives for train split (only for link prediction)
        dataset_return[0].resample_negatives = True
        return dataset_return
[docs] def split(
self,
transductive: bool = True,
split_ratio: List[float] = None,
split_types: Union[str, List[str]] = None,
shuffle: bool = True
) -> List[Graph]:
r"""
Split the dataset into train, validation (and test) sets.
Args:
transductive (bool): Whether the learning is transductive
(`True`) or inductive (`False`). Inductive split is
always used for the graph-level task, :obj:`self.task`
equals to `graph`.
split_ratio (list): A list of ratios such as
`[train_ratio, validation_ratio, test_ratio]`.
split_types (str or list): Types splitted on. Default is `None`.
shuffle (bool): Whether to shuffle data for the splitting.
Returns:
list: A list of 3 (2) :class:`deepsnap.dataset.GraphDataset`
objects corresponding to the train, validation (and test) sets.
"""
if self.graphs is None:
raise RuntimeError(
"Split is not supported for on-the-fly dataset. "
"Construct different on-the-fly datasets for train, val "
"and test. Or perform split at batch level."
)
if split_ratio is None:
split_ratio = [0.8, 0.1, 0.1]
if not isinstance(split_ratio, list):
raise TypeError("Split ratio must be a list.")
if len(split_ratio) > 3:
raise ValueError(
"Split ratio must contain less than or equal to three values."
)
if not math.isclose(sum(split_ratio), 1.0):
raise ValueError("Split ratio must sum up to 1.")
if not all(
isinstance(split_ratio_i, float)
for split_ratio_i in split_ratio
):
raise TypeError("Split ratio must contain all floats.")
if not all(split_ratio_i > 0 for split_ratio_i in split_ratio):
raise ValueError("Split ratio must contain all positivevalues.")
# store the most recent split types
self._split_types = split_types
# check self._is_tensor
if self._is_tensor:
for graph in self.graphs:
graph.edge_label_index = graph._edge_label_index
graph.edge_label = graph._edge_label
# list of num_splits datasets
dataset_return = []
if transductive:
if self.task == "graph":
raise ValueError(
"in transductive mode, self.task is graph does not "
"make sense."
)
dataset_return = (
self._split_transductive(
split_ratio, split_types, shuffle=shuffle
)
)
else:
dataset_return = (
self._split_inductive(
split_ratio,
split_types,
shuffle=shuffle
)
)
return dataset_return
    def resample_disjoint(self):
        r"""
        Resample splits of the message passing and supervision edges in the
        `disjoint` mode.

        .. note::

            If :meth:`apply_transform` (on the message passing graph)
            was used before this resampling, it needs to be
            re-applied after resampling, to update some of the (supervision)
            edges that were in the objectives.
        """
        # NOTE(review): __init__ assigns a boolean instance attribute also
        # named `resample_disjoint`, which shadows this method on
        # instances — confirm how this method is intended to be reached.
        if self.graphs is None:
            raise RuntimeError(
                "Resampling disjoint is not needed for on-the-fly dataset. "
                "Split the on-the-fly data as the batch arrives."
            )
        graphs = []
        for graph in self.graphs:
            # each graph re-splits its own message/supervision edges
            graphs.append(graph.resample_disjoint(self.edge_message_ratio))
        self.graphs = graphs
def _reset_cache(self):
r"""
Resets internal cache for graph examples, num_node_labels etc.
"""
self._num_node_labels = None
self._num_nodes = None
self._num_edge_labels = None
self._num_edges = None
self._num_graph_labels = None
# TODO: consider the heterogeneous graph case
if self.graphs is None:
self._graph_example = self.generator.generate()
if not isinstance(self._graph_example, Graph):
self._graph_example = Graph(self._graph_example)
else:
self._graph_example = self.graphs[0]
[docs] def filter(self, filter_fn, deep_copy: bool = False, **kwargs):
r"""
Filter the graphs in the dataset. Discarding a graph `G`
when `filter_fn(G)` is `False`. :meth:`apply_transform` is an
analog of the Python `map` function, while :meth:`filter`
is an analog of the Python `filter` function.
Args:
filter_fn: User-defined filter function that returns `True`
(keep) or `False` (discard) the graph object in
the dataset.
deep_copy: If `True`, all graphs will be deepcopied and
then fed into the :meth:`filter` function.
**kwargs: Parameters used in the :meth:`filter` function.
Returns:
:class:`GraphDataset`: A new :class:`GraphDataset` object with
graphs filtered.
"""
# currently does not support filter for on-the-fly dataset
if self.graphs is None:
raise ValueError(
"On-the-fly datasets do not support transform."
"Filter can be done at the batch level."
)
new_dataset = copy.copy(self)
new_dataset.graphs = [
graph for graph in self.graphs if filter_fn(graph, **kwargs)]
# update example graph used for num_node_features etc.
new_dataset._reset_cache()
return new_dataset
[docs] def to(self, device):
r"""
Transfer the graphs in the dataset to specified device.
Args:
device (str): Specified device name, such as `cpu` or
`cuda`.
"""
if self.graphs is None:
self.otf_device = device
else:
for graph in self.graphs:
graph.to(device)
def _shuffle(self):
r"""
shuffle Graph object in graphs.
"""
if self.graphs is not None:
random.shuffle(self.graphs)
    @staticmethod
    def pyg_to_graphs(
        dataset,
        verbose: bool = False,
        fixed_split: bool = False,
        tensor_backend: bool = False,
        netlib=None
    ) -> List[Graph]:
        r"""
        Transform a :class:`torch_geometric.data.Dataset` object to a
        list of :class:`deepsnap.graph.Graph` objects.

        Args:
            dataset (:class:`torch_geometric.data.Dataset`): A
                :class:`torch_geometric.data.Dataset` object that will be
                transformed to a list of :class:`deepsnap.graph.Graph`
                objects.
            verbose (bool): Whether to print information such as warnings.
            fixed_split (bool): Whether to load the fixed data split from
                the original PyTorch Geometric dataset.
            tensor_backend (bool): `True` will use pure tensors for graphs.
            netlib (types.ModuleType, optional): The graph backend module.
                Currently DeepSNAP supports the NetworkX and SnapX (for
                SnapX only the undirected homogeneous graph) as the graph
                backend. Default graph backend is the NetworkX.

        Returns:
            list: A list of :class:`deepsnap.graph.Graph` objects.
        """
        if fixed_split:
            graphs = [
                Graph.pyg_to_graph(
                    data, verbose=verbose, fixed_split=True,
                    tensor_backend=tensor_backend, netlib=netlib
                )
                for data in dataset
            ]
            # with fixed_split, converting the first data object yields
            # one graph per split; wrap each in its own list
            graphs_split = [[graph] for graph in graphs[0]]
            return graphs_split
        else:
            return [
                Graph.pyg_to_graph(
                    data, verbose=verbose,
                    tensor_backend=tensor_backend,
                    netlib=netlib
                )
                for data in dataset
            ]
    def __getitem__(self, idx: int) -> Union[Graph, List[Graph]]:
        r"""
        Takes in an integer (or a list of integers)
        returns a single Graph object (a subset of graphs).

        Args:
            idx: index to be selected from graphs.

        Returns:
            Union[:class:`deepsnap.graph.Graph`,
            List[:class:`deepsnap.graph.Graph`]]: A single
            :class:`deepsnap.graph.Graph` object or subset
            of :class:`deepsnap.graph.Graph` objects.
        """
        # TODO: add the hetero graph equivalent of these functions
        if self.graphs is None:
            # on-the-fly dataset: generate a fresh graph on each access
            graph = self.generator.generate()
            if not isinstance(graph, Graph):
                graph = Graph(graph)
            # generated an networkx graph
            if self.otf_device is not None:
                graph.to(self.otf_device)
        elif isinstance(idx, int):
            graph = self.graphs[idx]
        else:
            # sliceing of dataset
            dataset = self._index_select(idx)
            return dataset
        # resample disjoint training data only when the task is
        # disjoint link_pred and self.resample_disjoint is set to True
        if (
            self.task == "link_pred"
            and self.edge_train_mode == "disjoint"
            and self.resample_disjoint
            and graph._is_train
        ):
            if not hasattr(graph, "resample_disjoint_period"):
                graph.resample_disjoint_period = self.resample_disjoint_period
            if isinstance(graph, Graph):
                if isinstance(graph, HeteroGraph):
                    graph = graph.resample_disjoint(
                        split_types=self._split_types,
                        message_ratio=self.edge_message_ratio
                    )
                else:
                    graph = graph.resample_disjoint(self.edge_message_ratio)
            else:
                raise TypeError(
                    "element in self.graphs of unexpected type."
                )
        if self.task == "link_pred" and self.resample_negatives:
            resample_negative_flag = True
            # after graph just resampled disjoint training data
            # graph.edge_label is reset to original state,
            # the negative sampling process needs to utilize
            # resample=False mode.
            if (
                hasattr(graph, "_resample_disjoint_flag")
                and graph._resample_disjoint_flag
            ):
                resample_negative_flag = False
            # resample negative examples
            if isinstance(graph, Graph):
                if isinstance(graph, HeteroGraph):
                    if self.negative_edges_mode == "random":
                        graph._create_neg_sampling(
                            self.edge_negative_sampling_ratio,
                            split_types=self._split_types,
                            resample=resample_negative_flag
                        )
                    elif self.negative_edges_mode == "custom":
                        graph._custom_create_neg_sampling(
                            self.edge_negative_sampling_ratio,
                            split_types=self._split_types,
                            resample=resample_negative_flag
                        )
                else:
                    if self.negative_edges_mode == "random":
                        graph._create_neg_sampling(
                            self.edge_negative_sampling_ratio,
                            resample=resample_negative_flag
                        )
                    elif self.negative_edges_mode == "custom":
                        graph._custom_create_neg_sampling(
                            self.edge_negative_sampling_ratio,
                            resample=resample_negative_flag
                        )
            else:
                raise TypeError(
                    "element in self.graphs of unexpected type."
                )
        # write the (possibly resampled) graph back so the new state
        # persists across accesses
        if self.graphs is not None and isinstance(idx, int):
            self.graphs[idx] = graph
        return graph
    def __setitem__(self, key: str, value):
        """Sets the attribute :obj:`key` to :obj:`value` (plain setattr)."""
        setattr(self, key, value)
def _index_select(self, idx: int) -> List[Graph]:
r"""
Takes in a list of integers, returns a subset of graphs
corresponding to the list of integers.
Args:
idx: index to be selected from graphs.
Returns:
List[Graph]: a single Graph object or subset of graphs.
"""
if self.graphs is None:
# _index_select is only called when self.graphs is not None
raise NotImplementedError(
"Index select is not available for on-the-fly dataset."
)
if isinstance(idx, slice):
dataset = copy.copy(self)
dataset.graphs = self.graphs[idx]
elif torch.is_tensor(idx):
if (
idx.dtype == torch.long
or idx.dtype == torch.int
):
dataset = self._index_select(idx.tolist())
elif idx.dtype == torch.bool:
dataset = self._index_select(idx.nonzero().flatten().tolist())
else:
raise TypeError(
f"your index type is {idx.dtype}, only tensor of type "
"torch.long, torch.int or torch.bool are accepted."
)
elif isinstance(idx, list) or isinstance(idx, tuple):
dataset = copy.copy(self)
dataset.graphs = [self.graphs[x] for x in idx]
else:
raise IndexError(
"Only integers, slices (`:`), list, tuples, and long or bool "
f"tensors are valid indices (got {type(idx).__name__})."
)
return dataset
def __repr__(self) -> str: # pragma: no cover
descriptor = (
len(self) if self.graphs is not None else self.generator.__class__
)
return f"{self.__class__.__name__}({descriptor})"