import torch
from deepsnap.graph import Graph
from deepsnap.hetero_graph import HeteroGraph
from typing import (
Callable,
Dict,
List
)
class Batch(Graph):
    r"""
    A plain old python object modeling a batch of
    :class:`deepsnap.graph.Graph` objects as one big (disconnected) graph.
    It inherits from :class:`deepsnap.graph.Graph`, so the methods of that
    class can also be used here.

    When built through :meth:`from_data_list`, the assignment vector
    :obj:`batch` maps each node to the index of the graph it came from, and
    :obj:`__slices__` records the per-graph boundaries that
    :meth:`to_data_list` uses to split the batch again.

    .. note::

        For more detailed use of :class:`deepsnap.batch.Batch`, see the
        `examples
        <https://github.com/snap-stanford/deepsnap/tree/master/examples>`_
        folder.
    """

    def __init__(self, batch=None, **kwargs):
        r"""
        Args:
            batch: Node-to-graph assignment vector, or `None`.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.batch = batch
        # Class used by to_data_list() to instantiate the unbatched objects.
        self.__data_class__ = Graph
        # Stays None unless the batch is built via from_data_list().
        self.__slices__ = None

    @staticmethod
    def collate(follow_batch=None, transform=None, **kwargs):
        r"""
        Returns a function that turns a list of graphs into a
        :class:`deepsnap.batch.Batch` by calling :meth:`from_data_list`
        with the given arguments.

        Args:
            follow_batch (list): Keys for which assignment batch vectors
                are created. `None` (the default) means no keys.
            transform (callable): If it is not `None`, apply transform
                when batching.
            **kwargs: Other parameters, forwarded to :meth:`from_data_list`.
        """
        # `None` instead of a shared mutable `[]` default; from_data_list
        # normalizes `None` to an empty list.
        return lambda batch: Batch.from_data_list(
            batch, follow_batch, transform, **kwargs
        )

    @staticmethod
    def from_data_list(
        data_list: List[Graph],
        follow_batch: List = None,
        transform: Callable = None,
        **kwargs
    ):
        r"""
        Constructs a :class:`deepsnap.batch.Batch` object from a python list
        holding :class:`deepsnap.graph.Graph` objects.
        The assignment vector :obj:`batch` is created on the fly.
        Additionally, creates assignment batch vectors for each key in
        :obj:`follow_batch`.

        Args:
            data_list (list): A list of :class:`deepsnap.graph.Graph` objects.
            follow_batch (list): Creates assignment batch vectors
                for each key.
            transform (callable): If it is not `None`, apply transform
                when batching.
            **kwargs: Other parameters, passed to `apply_transform` when
                :obj:`transform` is given.

        Raises:
            TypeError: If an element of :obj:`data_list` is not a
                :class:`deepsnap.graph.Graph`.
        """
        if follow_batch is None:
            follow_batch = []
        if transform is not None:
            data_list = [
                data.apply_transform(
                    transform,
                    deep_copy=True,
                    **kwargs,
                )
                for data in data_list
            ]
        # Batch over the union of the keys found in any of the graphs.
        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        # "batch" is reserved for the assignment vector built below.
        assert "batch" not in keys

        batch, cumsum = Batch._init_batch_fields(keys, follow_batch)
        batch.__data_class__ = data_list[0].__class__
        batch.batch = []
        for i, data in enumerate(data_list):
            # Note: in heterogeneous graph, __inc__ logic is different
            Batch._collate_dict(
                data, cumsum,
                batch.__slices__, batch,
                data, follow_batch, i=i
            )
            if isinstance(data, Graph):
                if isinstance(data, HeteroGraph):
                    # num_nodes() yields a mapping here; sum over its values.
                    num_nodes = sum(data.num_nodes().values())
                else:
                    num_nodes = data.num_nodes
            else:
                raise TypeError(
                    "element in self.graphs of unexpected type"
                )
            if num_nodes is not None:
                item = torch.full((num_nodes, ), i, dtype=torch.long)
                batch.batch.append(item)

        # As in the original logic, only the node count of the last graph
        # decides whether the assignment vector is dropped.
        if num_nodes is None:
            batch.batch = None
        Batch._dict_list_to_tensor(batch, data_list[0])
        return batch.contiguous()

    @staticmethod
    def _init_batch_fields(keys, follow_batch):
        r"""Create an empty batch with one list per key to collate into.

        Args:
            keys: Attribute names to be batched.
            follow_batch: Keys that additionally get a ``"<key>_batch"``
                list.

        Returns:
            A tuple ``(batch, cumsum)``: the new :class:`Batch` whose
            ``__slices__`` start at ``[0]`` for every key, and the dictionary
            of cumulative offsets initialized to ``0`` for every key.
        """
        batch = Batch()
        batch.__slices__ = {key: [0] for key in keys}
        for key in keys:
            batch[key] = []
        for key in follow_batch:
            batch[f"{key}_batch"] = []
        cumsum = {key: 0 for key in keys}
        return batch, cumsum

    @staticmethod
    def _collate_dict(
        curr_dict,
        cumsum: Dict[str, int],
        slices,
        batched_dict,
        graph,
        follow_batch,
        i=None
    ):
        r""" Called in from_data_list to collate a dictionary.
        This can also be applied to Graph object, since it has support for
        keys and __getitem__().

        Args:
            curr_dict: current dictionary to be added to the
                collated dictionary.
            cumsum: cumulative sum to be used for indexing. Updated in
                place; entries for nested dictionaries become dictionaries.
            slices: a dictionary of the same structure as batched_dict,
                slices[key] indicates the indices to slice batch[key] into
                tensors for all graphs in the batch. Updated in place.
            batched_dict: the batched dictionary of the same structure
                as curr_dict. But all graph data are batched together.
                Updated in place.
            graph: the graph whose ``__cat_dim__`` and ``__inc__`` are used.
            follow_batch: keys for which a ``"<key>_batch"`` assignment
                list is filled.
            i: index of the current graph within the batch.
        """
        if isinstance(curr_dict, dict):
            keys = curr_dict.keys()
        else:
            keys = curr_dict.keys
        for key in keys:
            item = curr_dict[key]
            if isinstance(item, dict):
                # recursively collate every key in the dictionary
                if isinstance(batched_dict[key], list):
                    # nested dictionary not initialized yet
                    assert len(batched_dict[key]) == 0
                    # initialize the nested dictionary for batch
                    cumsum[key] = {inner_key: 0 for inner_key in item}
                    slices[key] = {inner_key: [0] for inner_key in item}
                    batched_dict[key] = {inner_key: [] for inner_key in item}
                    # The recursive call appends to "<inner_key>_batch", so
                    # the list must be named after the inner key. Keys that
                    # are absent from this nested dictionary are skipped:
                    # their list would stay empty and break the later
                    # conversion to tensors.
                    for inner_key in follow_batch:
                        if inner_key in item:
                            batched_dict[key][f"{inner_key}_batch"] = []
                Batch._collate_dict(
                    item, cumsum[key],
                    slices[key], batched_dict[key],
                    graph, follow_batch, i=i
                )
                continue
            if torch.is_tensor(item) and item.dtype != torch.bool:
                # Shift by the offset accumulated from the previous graphs.
                item = item + cumsum[key]
            if torch.is_tensor(item):
                size = item.size(graph.__cat_dim__(key, curr_dict[key]))
            else:
                size = 1
            slices[key].append(size + slices[key][-1])
            cumsum[key] = cumsum[key] + graph.__inc__(key, item)
            batched_dict[key].append(item)

            if key in follow_batch:
                item = torch.full((size, ), i, dtype=torch.long)
                batched_dict[f"{key}_batch"].append(item)

    @staticmethod
    def _dict_list_to_tensor(dict_of_list, graph):
        r"""Convert a dict/Graph with list as values to a dict/Graph with
        concatenated/stacked tensor as values.

        The conversion is done in place. Lists whose first element is
        neither a tensor, a float nor an int are left as lists.

        Args:
            dict_of_list: dictionary or Graph whose values are lists
                (or nested dictionaries of lists).
            graph: the graph whose ``__cat_dim__`` selects the
                concatenation dimension.
        """
        if isinstance(dict_of_list, dict):
            keys = dict_of_list.keys()
        else:
            keys = dict_of_list.keys
        for key in keys:
            if isinstance(dict_of_list[key], dict):
                # recursively convert the dictionary of list to dict of tensor
                Batch._dict_list_to_tensor(dict_of_list[key], graph)
                continue
            # The first element decides how the whole list is converted.
            item = dict_of_list[key][0]
            if torch.is_tensor(item):
                if (
                    Graph._is_graph_attribute(key)
                    and item.ndim == 1
                    and item.dtype != torch.long
                    and "feature" in key
                ):
                    # special consideration: 1D tensor for graph
                    # attribute (classification)
                    # named as: "graph_xx_feature"
                    # batch by stacking the first dim
                    dict_of_list[key] = torch.stack(
                        dict_of_list[key],
                        dim=0
                    )
                else:
                    # concat at the __cat_dim__
                    dict_of_list[key] = torch.cat(
                        dict_of_list[key],
                        dim=graph.__cat_dim__(key, item)
                    )
            elif isinstance(item, (float, int)):
                dict_of_list[key] = torch.tensor(dict_of_list[key])

    @staticmethod
    def _count_graphs(slices, keys=None):
        r"""Return the number of graphs recorded in a slices dictionary.

        Looks for the first slice list, descending into nested dictionaries,
        and returns its length minus one (each list starts with ``0`` and
        gains one entry per collated graph).

        Args:
            slices: a (possibly nested) dictionary of slice lists.
            keys: keys to inspect, in order. Defaults to all keys of
                :obj:`slices`.

        Returns:
            The number of graphs, or `None` if no slice list is found.
        """
        if keys is None:
            keys = slices.keys()
        for key in keys:
            entry = slices.get(key)
            if entry is None:
                continue
            if isinstance(entry, dict):
                count = Batch._count_graphs(entry)
                if count is not None:
                    return count
            else:
                return len(entry) - 1
        return None

    def to_data_list(self):
        r"""
        Reconstructs the list of graph objects (instances of the class of
        the first batched graph) from the batch object.
        The batch object must have been created via :meth:`from_data_list` in
        order to be able to reconstruct the initial objects.

        Returns:
            list: One reconstructed object per graph in the batch.

        Raises:
            RuntimeError: If the batch was not created via
                :meth:`from_data_list`.
        """
        if self.__slices__ is None:
            raise RuntimeError(
                "Cannot reconstruct data list from batch because the "
                "batch object was not created using Batch.from_data_list()"
            )

        # Skip "batch" and the "<key>_batch" assignment vectors.
        keys = [key for key in self.keys if key[-5:] != "batch"]
        cumsum = {key: 0 for key in keys}
        # The slices entry of a key can be a nested dictionary, so its
        # length is not the number of graphs; use the first slice list.
        num_graphs = Batch._count_graphs(self.__slices__, keys) or 0
        data_list = []
        for i in range(num_graphs):
            # i: from 0 up to num graphs in the batch
            data = self.__data_class__()
            self._reconstruct_dict(
                i, keys, data, cumsum, self.__slices__, self, data
            )
            data_list.append(data)

        return data_list

    def _reconstruct_dict(
        self, graph_idx: int, keys, data_dict,
        cumsum: Dict[str, int], slices, batched_dict, graph
    ):
        r"""Fill :obj:`data_dict` with the part of :obj:`batched_dict` that
        belongs to graph :obj:`graph_idx`.

        Args:
            graph_idx: index of the graph to extract.
            keys: keys of :obj:`batched_dict` to extract.
            data_dict: dictionary or Graph receiving the extracted values.
            cumsum: cumulative offsets, carried over between successive
                calls and updated in place.
            slices: dictionary of the same structure as batched_dict
                holding the per-graph boundaries.
            batched_dict: the batched dictionary or Batch to read from.
            graph: the graph whose ``__cat_dim__`` and ``__inc__`` are used.
        """
        for key in keys:
            if isinstance(batched_dict[key], dict):
                # recursively unbatch the dict
                data_dict[key] = {}
                inner_keys = [
                    inner_key
                    for inner_key in batched_dict[key].keys()
                    if inner_key[-5:] != "batch"
                ]
                # Keep the nested offsets inside the parent dictionary so
                # they persist across graphs, mirroring _collate_dict;
                # rebuilding them on every call would reset them to zero.
                inner_cumsum = cumsum.get(key)
                if not isinstance(inner_cumsum, dict):
                    inner_cumsum = {inner_key: 0 for inner_key in inner_keys}
                    cumsum[key] = inner_cumsum
                self._reconstruct_dict(
                    graph_idx, inner_keys,
                    data_dict[key], inner_cumsum,
                    slices[key], batched_dict[key], graph
                )
                continue

            start = slices[key][graph_idx]
            end = slices[key][graph_idx + 1]
            if torch.is_tensor(batched_dict[key]):
                data_dict[key] = batched_dict[key].narrow(
                    graph.__cat_dim__(key, batched_dict[key]),
                    start,
                    end - start
                )
                if batched_dict[key].dtype != torch.bool:
                    # Undo the offset that was added while collating.
                    data_dict[key] = data_dict[key] - cumsum[key]
            else:
                data_dict[key] = batched_dict[key][start:end]
            cumsum[key] = cumsum[key] + graph.__inc__(key, data_dict[key])

    @property
    def num_graphs(self) -> int:
        r"""
        Returns the number of graphs in the batch.

        The value is the last entry of the assignment vector :obj:`batch`
        plus one. If :obj:`batch` is `None` but slice information is
        available, the count recorded in the slices is returned instead.

        Returns:
            int: The number of graphs in the batch.
        """
        if self.batch is None and self.__slices__ is not None:
            count = Batch._count_graphs(self.__slices__)
            if count is not None:
                return count
        return self.batch[-1].item() + 1