Tutorial 14: Data Handling in PyG (Part 1)

[ ]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.12.1+cu113
     |████████████████████████████████| 7.9 MB 3.7 MB/s
     |████████████████████████████████| 3.5 MB 7.4 MB/s
  Building wheel for torch-geometric (setup.py) ... done
[1]:
import numpy as np
import torch
import torch_geometric.datasets as datasets
import torch_geometric.data as data
import torch_geometric.transforms as transforms
import networkx as nx
from torch_geometric.utils.convert import to_networkx

Data

Let’s create a dummy graph

[2]:
embeddings = torch.rand((100, 16), dtype=torch.float)
[3]:
rows = np.random.choice(100, 500)
cols = np.random.choice(100, 500)
edges = torch.tensor([rows, cols])
/home/user/anaconda3/envs/gnn/lib/python3.7/site-packages/ipykernel_launcher.py:3: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  /opt/conda/conda-bld/pytorch_1646755953518/work/torch/csrc/utils/tensor_new.cpp:210.)
  This is separate from the ipykernel package so we can avoid doing imports until
[4]:
edges_attr = np.random.choice(3,500)
[5]:
ys = torch.rand((100)).round().long()

Convert the graph information into a PyG Data object

[6]:
graph = data.Data(x=embeddings, edge_index=edges, edge_attr=edges_attr, y=ys)
[7]:
graph
[7]:
Data(x=[100, 16], edge_index=[2, 500], edge_attr=[500], y=[100])

Let’s visualize the information contained in the data object

[8]:
for prop in graph:
    print(prop)
('x', tensor([[0.7074, 0.6998, 0.3472,  ..., 0.4214, 0.9162, 0.8406],
        [0.1817, 0.2383, 0.0654,  ..., 0.8730, 0.5678, 0.2866],
        [0.8566, 0.6106, 0.7784,  ..., 0.4809, 0.3192, 0.6875],
        ...,
        [0.0338, 0.8082, 0.4665,  ..., 0.6577, 0.8019, 0.8696],
        [0.0920, 0.6424, 0.3201,  ..., 0.9704, 0.7353, 0.1449],
        [0.6611, 0.6635, 0.0473,  ..., 0.1450, 0.7180, 0.7524]]))
('edge_index', tensor([[23,  6, 58, 82, 62, 20, 48, 99,  0,  4, 96, 79, 55, 15, 94, 34, 39, 99,
         66, 37, 23, 67, 68, 99, 83, 96,  5, 58, 66, 79, 97, 35, 47, 67, 85, 65,
         80, 24, 31, 22, 15, 45, 61, 28, 71, 71, 50, 67, 28, 79, 63, 52,  3, 86,
          5, 17, 68, 88, 20, 52, 17, 44, 69, 12, 62, 41, 44, 56, 32, 39,  8, 35,
         68,  7, 48, 18, 39,  4, 79, 92,  1, 20, 29, 82, 58, 20,  1, 99, 39, 30,
          9, 61, 86, 93, 37, 29, 88, 40,  0, 72, 12, 87, 70, 17,  3,  7, 64, 46,
         55, 35, 26, 13, 73, 81, 36, 49, 50, 47, 31, 79,  3, 68, 31,  1, 27, 53,
         72, 76, 71, 55, 69, 62, 14, 82, 53, 64,  6, 89, 64, 27,  0, 26, 86, 16,
         75, 40, 76, 53, 75, 41, 16,  9, 71, 29, 72, 66, 61, 29, 95, 85, 10, 13,
         58, 45,  2,  9, 75, 10, 41, 70, 34, 16, 20, 82, 17, 89, 80, 31, 37, 77,
         69, 18, 71, 33, 88, 52, 30, 73, 97, 19, 20, 17, 83, 70, 43, 79, 71, 61,
         95, 27, 37, 88, 78, 38, 65, 85, 52, 47,  1, 21, 96, 78, 48, 25, 22, 23,
         66, 60, 52, 84, 50, 49, 92, 57, 88, 24, 66, 75, 73, 39, 59, 13, 25, 36,
         50, 19, 14, 84, 65, 49, 91, 75, 79, 29, 16, 60, 64, 25, 41, 84, 95, 75,
         51, 57, 64, 32, 14, 74, 34, 25, 36, 51, 52, 96, 11, 56, 57, 57, 83, 97,
         48, 89,  4, 62, 83, 27, 71, 24, 41, 47,  4, 61, 79, 67, 87, 58, 75, 18,
         56, 52, 15, 53, 77,  1,  3, 27, 61, 78, 26, 90, 97, 29, 34, 76, 77, 39,
         14, 23, 68, 48, 47, 34, 67, 18, 86, 69, 59,  0, 79, 78, 55, 97, 10, 51,
         86, 60, 12,  8,  4, 49, 57, 93, 70, 55, 70,  2, 86, 31, 10, 20,  4, 26,
          8, 42, 17, 15, 18, 30, 68, 27, 74, 96, 24, 29, 86, 33, 74, 69, 68, 93,
         18, 42, 44, 90, 69, 51,  7, 22, 46, 38, 54, 15, 77, 40, 42, 16,  9, 32,
          2, 99, 95, 98, 89, 50, 54, 50, 89, 20, 99, 31, 99, 37, 71, 30, 47, 78,
         96, 49, 22,  7, 25,  4, 66, 20, 49, 46, 24, 38, 27, 31,  1, 34,  5, 72,
         98,  3, 39, 24, 92, 94, 82, 28, 65, 66,  5, 13, 16, 11, 93, 28,  5, 93,
         69, 39, 25, 56, 52, 13, 94, 20, 97, 63, 85, 44, 49, 55, 25, 39, 82, 20,
         60, 32, 22, 97, 42, 24, 59, 62, 54, 89, 53, 19, 10, 51,  8,  4, 50, 82,
          5, 76, 77, 27, 41, 43, 46, 23, 54, 36,  9,  3, 28, 55, 91,  6, 98, 94,
         20, 46, 97, 18, 94, 42,  4, 72, 40, 68, 25, 24, 81, 45],
        [23, 70, 24, 51, 68, 55, 88,  9, 94, 41,  2, 14, 27, 27, 44, 66, 28,  4,
         15, 14, 61, 88, 76, 33, 75, 36, 11, 54, 26, 70,  8, 54, 56, 97, 14, 94,
         82, 39, 73, 83, 15,  3, 40, 46, 18, 96, 15, 11, 74, 20,  6, 38, 14,  5,
         58, 17,  5, 22, 48, 81, 44, 47, 95, 26, 20,  0, 57, 98, 84, 49, 49, 56,
         82, 45, 91, 55,  6, 52,  6, 35, 20, 11,  2, 36, 23, 25, 38, 87, 68, 24,
         17, 45, 74, 15, 76, 10, 36, 40, 72, 35, 14, 11, 20, 89, 13, 50, 92, 41,
         16, 98,  4, 13, 50, 54, 29, 12, 11, 70, 93, 93,  5, 78, 37, 50, 87, 89,
         58, 58, 25, 78, 48, 33,  2, 67, 72, 59, 74, 47, 47, 67,  8, 77, 36, 83,
         67,  4, 20,  5, 62, 56, 28, 16, 85, 77,  2, 70, 63, 84, 92, 14, 25, 24,
         85, 82, 65, 35,  2, 49, 73, 32, 54, 56, 62, 27, 46, 18,  9,  8, 24, 87,
         75, 41, 42, 32,  9, 61,  0, 18, 79, 20, 24, 95, 54, 25,  8,  1, 34, 87,
         28, 32, 21, 45, 12, 65, 54, 77, 68, 34, 26,  1, 78, 37, 27, 72, 60, 80,
         96, 14, 94, 13, 16, 74, 31, 59, 78, 13, 61, 86, 46, 60, 17, 88, 19, 31,
         43, 13, 53, 95, 34, 41,  2, 61, 84, 79, 41, 97, 23, 82, 31, 27, 64, 58,
          5, 87, 63, 88, 15, 65, 40, 78, 57, 98,  3, 80,  9, 18, 66, 50, 29, 11,
         74, 88, 59,  5, 63, 99, 46, 86, 51, 91, 59, 76, 19,  5, 85, 92,  0, 51,
         38, 39,  8, 46, 23, 52, 29, 22, 74, 51, 24,  2, 59, 41, 36, 98, 46, 65,
         71, 55, 88, 18, 47, 62, 59, 30, 45, 59, 38, 11, 65, 40,  7, 89, 67, 20,
         27, 96, 97, 60, 17, 94, 43, 85, 56, 63, 40, 94, 92, 33, 57, 22, 54, 86,
         60, 58, 92, 81, 86, 66, 91, 29, 85, 83, 65, 65, 35, 99, 63, 54, 86, 74,
          7, 52, 55, 91, 35,  7, 52, 72, 60, 84, 48, 17, 69, 20, 51, 92, 38, 71,
         11, 20, 12, 90, 26, 39, 42, 12,  7, 95, 57, 93, 44, 43, 53, 92, 42, 17,
         13, 72, 55, 20,  3, 16, 82, 56, 86, 75, 30, 55, 95, 77, 64,  3, 18, 20,
         72, 94, 70, 60,  4, 39, 71, 11, 43, 71, 53, 31, 79, 64,  7, 99, 19, 44,
          4, 29, 66, 31, 65, 84, 67,  9, 49, 81, 28,  3, 58, 25, 43, 43, 18, 60,
         39, 25, 45, 37, 49, 48, 26, 42, 43, 94, 21, 31, 20, 64, 67, 75,  5, 78,
         14, 55, 86, 58, 68, 35, 25, 99, 24,  0, 29, 38,  6, 12, 89, 86, 55, 83,
         21, 17, 23, 66, 78, 82, 32, 84, 90, 91, 43,  2,  7, 48]]))
('edge_attr', array([2, 2, 0, 1, 1, 2, 2, 2, 0, 0, 1, 0, 0, 0, 1, 1, 0, 2, 0, 0, 2, 1,
       1, 0, 2, 1, 0, 0, 0, 2, 1, 1, 1, 2, 0, 1, 0, 1, 1, 1, 1, 2, 0, 2,
       1, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0, 1, 0, 2, 1, 1, 1, 1, 2, 1, 2, 1,
       0, 0, 2, 1, 1, 2, 1, 2, 0, 1, 2, 1, 0, 2, 0, 0, 0, 0, 1, 2, 2, 1,
       0, 2, 1, 0, 2, 0, 1, 1, 0, 2, 0, 0, 1, 0, 1, 2, 2, 0, 2, 2, 0, 1,
       2, 1, 0, 1, 2, 0, 2, 0, 1, 2, 2, 0, 1, 1, 2, 1, 2, 1, 1, 0, 2, 0,
       1, 2, 1, 2, 0, 2, 1, 1, 1, 2, 2, 1, 0, 0, 1, 2, 0, 2, 1, 1, 1, 2,
       2, 1, 0, 0, 2, 1, 1, 2, 1, 2, 1, 0, 2, 2, 2, 1, 1, 2, 2, 1, 2, 0,
       1, 0, 0, 0, 1, 1, 1, 2, 0, 0, 1, 2, 2, 0, 1, 1, 2, 2, 1, 2, 1, 0,
       2, 0, 1, 0, 1, 1, 0, 2, 2, 2, 0, 0, 1, 1, 1, 2, 0, 0, 1, 1, 1, 0,
       0, 0, 2, 0, 0, 2, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2, 0, 0, 0, 2, 1, 0,
       0, 2, 0, 2, 0, 2, 0, 1, 2, 0, 0, 1, 0, 0, 0, 2, 2, 0, 2, 2, 0, 1,
       0, 2, 2, 0, 0, 1, 0, 1, 2, 1, 2, 1, 1, 0, 2, 2, 1, 1, 2, 0, 2, 0,
       0, 2, 2, 0, 0, 1, 1, 1, 0, 0, 2, 2, 1, 2, 2, 2, 0, 1, 1, 2, 0, 2,
       0, 2, 1, 2, 1, 2, 2, 0, 1, 1, 0, 1, 2, 2, 1, 0, 0, 2, 1, 1, 2, 2,
       0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1, 1, 0, 0, 2, 1, 1, 1, 2, 0, 0, 2,
       0, 0, 2, 2, 0, 0, 0, 0, 2, 2, 0, 0, 1, 1, 0, 0, 1, 2, 0, 2, 2, 2,
       2, 0, 0, 0, 2, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 0, 1, 0, 1, 1, 1, 0,
       2, 2, 2, 1, 1, 1, 2, 2, 1, 0, 1, 1, 0, 0, 2, 2, 0, 0, 2, 0, 0, 2,
       1, 0, 2, 1, 2, 2, 1, 1, 2, 1, 2, 0, 0, 0, 2, 2, 2, 0, 1, 1, 0, 1,
       1, 0, 2, 1, 0, 0, 0, 2, 2, 2, 2, 2, 0, 0, 2, 1, 1, 1, 2, 1, 2, 1,
       1, 2, 1, 2, 0, 1, 2, 2, 2, 1, 1, 2, 0, 2, 2, 2, 2, 1, 0, 2, 1, 0,
       1, 2, 0, 2, 0, 0, 1, 2, 2, 2, 2, 2, 1, 0, 1, 0]))
('y', tensor([0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0,
        0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0,
        0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1,
        1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0]))
[9]:
vis = to_networkx(graph)

node_labels = graph.y.numpy()

import matplotlib.pyplot as plt
plt.figure(1,figsize=(15,13))
nx.draw(vis, cmap=plt.get_cmap('Set3'),node_color = node_labels,node_size=70,linewidths=6)
plt.show()
../../_images/ipynbs_pyg_tutorial_project_Tutorial14_14_0.png

Batch

With the Batch object we can represent multiple graphs as a single disconnected graph

[10]:
graph2 = graph
[11]:
batch = data.Batch().from_data_list([graph, graph2])
[12]:
print("Number of graphs:",batch.num_graphs)
print("Graph at index 1:",batch[1])
print("Retrieve the list of graphs:\n",len(batch.to_data_list()))
Number of graphs: 2
Graph at index 1: Data(x=[100, 16], edge_index=[2, 500], edge_attr=[500], y=[100])
Retrieve the list of graphs:
 2

Cluster

ClusterData groups the nodes of a graph into a specific number of cluster for faster computation in large graphs, then use ClusterLoader to load batches of clusters

[13]:
#cluster = data.ClusterData(graph, 5)
[14]:
#clusterloader = data.ClusterLoader(cluster)

Sampler

For each convolutional layer, sample a maximum of nodes from each neighborhood (as in GraphSAGE)

[15]:
sampler = data.NeighborSampler(graph.edge_index, sizes=[3,10], batch_size=4,
                                  shuffle=False)
/home/user/anaconda3/envs/gnn/lib/python3.7/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.NeighborSampler' is deprecated, use 'loader.NeighborSampler' instead
  warnings.warn(out)
[16]:
for s in sampler:
    print(s)
    break
(4, tensor([ 0,  1,  2,  3, 75, 36, 30, 79, 21, 14, 29, 34, 44, 25, 41, 96, 91, 90,
        72, 24, 52, 45, 83, 69, 46,  4, 88, 86, 82, 18, 97, 16, 53, 37, 20, 85,
        60, 12,  5, 39, 27,  9, 71, 65, 47, 99, 94, 93, 17, 70, 55, 32, 10]), [EdgeIndex(edge_index=tensor([[ 4,  5,  6, 14,  7,  8,  4,  9, 10, 15, 16, 17, 18, 19, 11, 12, 13, 20,
         21, 22, 23, 24, 25, 11, 15, 26, 27, 28, 19, 29, 10, 30, 31, 32, 33, 34,
          3,  7, 33, 35, 35, 36, 37, 38,  3,  5, 22, 39, 40, 41, 42, 43, 44, 45,
         46, 47, 48, 24, 34, 42, 49, 50, 51, 52],
        [ 0,  0,  0,  0,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,
          3,  4,  4,  4,  4,  5,  5,  5,  5,  5,  6,  6,  7,  7,  7,  8,  8,  8,
          9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 11, 11, 11, 12,
         12, 12, 12, 13, 13, 13, 13, 13, 13, 13]]), e_id=tensor([286, 477, 186,  65, 195, 209, 166, 132,  82,  10, 240, 299, 154, 497,
        411, 443, 400, 262,  41,  24, 180, 405, 465, 302,  25,  96, 142,  83,
        406, 313, 243, 188, 426, 460, 200, 486,  52,  11,  19, 159,  34, 217,
        100, 468, 294, 114, 268, 433, 349, 478, 196, 238, 207, 390,  14, 431,
         60, 474,  85, 128, 193, 445, 451, 160]), size=(53, 14)), EdgeIndex(edge_index=tensor([[ 4,  5,  6,  7,  8,  4,  9, 10, 11, 12, 13],
        [ 0,  0,  0,  1,  1,  2,  2,  2,  3,  3,  3]]), e_id=tensor([286, 477, 186, 195, 209, 166, 132,  82, 411, 443, 400]), size=(14, 4))])
[17]:
print("Batch size:", s[0])
print("Number of unique nodes involved in the sampling:",len(s[1]))
print("Number of neighbors sampled:", len(s[2][0].edge_index[0]), len(s[2][1].edge_index[0]))
Batch size: 4
Number of unique nodes involved in the sampling: 53
Number of neighbors sampled: 64 11

Datasets

List all the available datasets

[18]:
datasets.__all__
[18]:
['KarateClub',
 'TUDataset',
 'GNNBenchmarkDataset',
 'Planetoid',
 'FakeDataset',
 'FakeHeteroDataset',
 'NELL',
 'CitationFull',
 'CoraFull',
 'Coauthor',
 'Amazon',
 'PPI',
 'Reddit',
 'Reddit2',
 'Flickr',
 'Yelp',
 'AmazonProducts',
 'QM7b',
 'QM9',
 'MD17',
 'ZINC',
 'AQSOL',
 'MoleculeNet',
 'Entities',
 'RelLinkPredDataset',
 'GEDDataset',
 'AttributedGraphDataset',
 'MNISTSuperpixels',
 'FAUST',
 'DynamicFAUST',
 'ShapeNet',
 'ModelNet',
 'CoMA',
 'SHREC2016',
 'TOSCA',
 'PCPNetDataset',
 'S3DIS',
 'GeometricShapes',
 'BitcoinOTC',
 'ICEWS18',
 'GDELT',
 'DBP15K',
 'WILLOWObjectClass',
 'PascalVOCKeypoints',
 'PascalPF',
 'SNAPDataset',
 'SuiteSparseMatrixCollection',
 'AMiner',
 'WordNet18',
 'WordNet18RR',
 'WikiCS',
 'WebKB',
 'WikipediaNetwork',
 'Actor',
 'OGB_MAG',
 'DBLP',
 'MovieLens',
 'IMDB',
 'LastFM',
 'HGBDataset',
 'JODIEDataset',
 'MixHopSyntheticDataset',
 'UPFD',
 'GitHub',
 'FacebookPagePage',
 'LastFMAsia',
 'DeezerEurope',
 'GemsecDeezer',
 'Twitch',
 'Airports',
 'BAShapes',
 'MalNetTiny',
 'OMDB',
 'PolBlogs',
 'EmailEUCore',
 'StochasticBlockModelDataset',
 'RandomPartitionGraphDataset',
 'LINKXDataset',
 'EllipticBitcoinDataset']
[20]:
name = 'Cora'
transform = transforms.Compose([
    transforms.RandomNodeSplit('train_rest', num_val=500, num_test=500),
    transforms.TargetIndegree(),
])
cora = datasets.Planetoid('./data', name, pre_transform=transforms.NormalizeFeatures(), transform=transform)
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
[21]:
aids = datasets.TUDataset(root="./data", name="AIDS")
Downloading https://www.chrsmrrs.com/graphkerneldatasets/AIDS.zip
Extracting data/AIDS/AIDS.zip
Processing...
Done!
[22]:
print("AIDS info:")
print('# of graphs:', len(aids))
print('# Classes (graphs)', aids.num_classes)
print('# Edge features', aids.num_edge_features)
print('# Edge labels', aids.num_edge_labels)
print('# Node features', aids.num_node_features)
AIDS info:
# of graphs: 2000
# Classes (graphs) 2
# Edge features 3
# Edge labels 3
# Node features 38
[23]:
print("Cora info:")
print('# of graphs:', len(cora))
print('# Classes (nodes)', cora.num_classes)
print('# Edge features', cora.num_edge_features)
print('# Node features', cora.num_node_features)
Cora info:
# of graphs: 1
# Classes (nodes) 7
# Edge features 1
# Node features 1433
[24]:
aids.data
[24]:
Data(x=[31385, 38], edge_index=[2, 64780], edge_attr=[64780, 3], y=[2000])
[25]:
aids[0]
[25]:
Data(edge_index=[2, 106], x=[47, 38], edge_attr=[106, 3], y=[1])
[26]:
cora.data
[26]:
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
[27]:
cora[0]
[27]:
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_attr=[10556, 1])
[28]:
cora_loader = data.DataLoader(cora)
/home/user/anaconda3/envs/gnn/lib/python3.7/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
[29]:
for l in cora_loader:
    print(l)
    break
DataBatch(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_attr=[10556, 1], batch=[2708], ptr=[2])