Tutorial 7: ARGA & ARVGA

Paper: Adversarially Regularized Graph Autoencoder for Graph Embedding (Pan et al., IJCAI 2018)

Code: ARGA & ARVGA (PyTorch Geometric implementation); example on node clustering

[ ]:
import os
import torch
# Expose the installed torch version (e.g. "1.12.1+cu113") as an environment
# variable so the shell commands below can reference it as ${TORCH} when
# choosing the matching PyG wheel index.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# The `!` lines are IPython shell escapes (notebook-only, not plain Python):
# install the compiled PyG extensions built for this torch version, then the
# latest PyTorch Geometric from GitHub.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.12.1+cu113
     |████████████████████████████████| 7.9 MB 9.3 MB/s
     |████████████████████████████████| 3.5 MB 9.6 MB/s
  Building wheel for torch-geometric (setup.py) ... done

Imports

[ ]:
import os.path as osp
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import (v_measure_score, homogeneity_score, completeness_score)
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models.autoencoder import ARGVA
from torch_geometric.utils import train_test_split_edges
[ ]:
# Set to True to train on the GPU (it is only used when CUDA is also available).
use_cuda = False

Define the dataset

Download the dataset

[ ]:
# Download (on first run) and load the Cora citation graph, row-normalizing
# the node features.
dataset_name = 'Cora'
path = osp.join('.', 'data', dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
# Index with [] instead of calling .get(): Dataset.__getitem__ is what applies
# `transform`, whereas .get() returns the raw graph, so the
# NormalizeFeatures transform above would silently be skipped.
data = dataset[0]
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!

Get the number of nodes

[ ]:
# Number of nodes = number of rows in the node-feature matrix.
num_nodes = data.x.size(0)

Split the edges into training, validation, and test sets

[ ]:
# The node-classification masks are not used in this link-prediction setup,
# so drop them before splitting the edges.
data.train_mask = None
data.val_mask = None
data.test_mask = None

# Hold out 5% of the edges for validation and 10% for testing (the function's
# default ratios, spelled out); negative edges are sampled for both splits.
data = train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1)
data
/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:12: UserWarning: 'train_test_split_edges' is deprecated, use 'transforms.RandomLinkSplit' instead
  warnings.warn(out)
Data(x=[2708, 1433], y=[2708], val_pos_edge_index=[2, 263], test_pos_edge_index=[2, 527], train_pos_edge_index=[2, 8976], train_neg_adj_mask=[2708, 2708], val_neg_edge_index=[2, 263], test_neg_edge_index=[2, 527])

Define the model

Define the variational encoder class (the same as in Tutorial 6)

[ ]:
class VEncoder(torch.nn.Module):
    """Two-layer GCN encoder that returns ``(mu, logstd)``.

    A shared GCN layer maps node features to ``2 * out_channels`` hidden
    units (followed by ReLU). Two parallel GCN heads then produce ``mu`` and
    ``logstd``, each with ``out_channels`` columns. ARGVA consumes them as the
    parameters of the latent distribution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        # cached=True lets GCNConv reuse its normalized adjacency after the
        # first call (see GCNConv docs), so the same edge_index should be
        # passed on every forward.
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)
        self.conv_mu = GCNConv(hidden_channels, out_channels, cached=True)
        self.conv_logstd = GCNConv(hidden_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        mu = self.conv_mu(hidden, edge_index)
        logstd = self.conv_logstd(hidden, edge_index)
        return mu, logstd

Define the discriminator class

[ ]:
class Discriminator(torch.nn.Module):
    """Three-layer MLP that scores latent vectors.

    Two ReLU-activated hidden layers of width ``hidden_channels`` are followed
    by a linear output layer with ``out_channels`` units. The output is
    returned without any final activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin3 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        h = x
        for hidden_layer in (self.lin1, self.lin2):
            h = F.relu(hidden_layer(h))
        # No activation on the last layer: callers receive raw scores.
        return self.lin3(h)

Define the training algorithm

[ ]:
def train():
    """Run one adversarial training step and return the encoder loss.

    The discriminator is updated five times on the current node embedding.
    The encoder is then updated once on the sum of:

    - the adversarial regularizer,
    - the edge-reconstruction loss, and
    - the KL term scaled by ``1 / num_nodes``.

    Returns:
        float: The encoder loss for this step. It is a plain Python number,
        so callers do not keep the autograd graph alive between epochs.
    """
    model.train()
    discriminator.train()
    encoder_optimizer.zero_grad()

    z = model.encode(data.x, data.train_pos_edge_index)

    # Train the discriminator more often than the encoder. Feeding it a
    # detached embedding means its backward passes never touch the encoder
    # graph, so that graph stays intact for the encoder update below and no
    # `retain_graph=True` is needed.
    z_detached = z.detach()
    for _ in range(5):
        discriminator_optimizer.zero_grad()
        discriminator_loss = model.discriminator_loss(z_detached)
        discriminator_loss.backward()
        discriminator_optimizer.step()

    # Encoder objective: adversarial regularization + reconstruction of the
    # training edges + KL divergence scaled by 1/N.
    loss = model.reg_loss(z)
    loss = loss + model.recon_loss(z, data.train_pos_edge_index)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    encoder_optimizer.step()

    return float(loss)

Define the evaluation method (link-prediction AUC/AP on the test edges, plus k-means clustering scores against the node labels)

[ ]:
@torch.no_grad()
def test(n_clusters=None):
    """Evaluate the current embedding on node clustering and link prediction.

    Nodes are embedded by the model in eval mode and clustered with k-means.
    The clusters are compared with the ground-truth labels ``data.y``, and
    the held-out test edges are scored with ``model.test``.

    Args:
        n_clusters: Number of k-means clusters. Defaults to
            ``dataset.num_classes`` (7 for Cora, the value that was
            previously hard-coded).

    Returns:
        Tuple ``(auc, ap, completeness, homogeneity, nmi)``.
    """
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)

    if n_clusters is None:
        n_clusters = dataset.num_classes

    # n_init is pinned to 10 (the historical scikit-learn default) so results
    # don't change silently when newer releases switch the default.
    # fit_predict is documented as equivalent to fit(X) followed by predict(X).
    embeddings = z.cpu().numpy()
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    pred = kmeans.fit_predict(embeddings)

    labels = data.y.cpu().numpy()
    completeness = completeness_score(labels, pred)
    hm = homogeneity_score(labels, pred)
    # V-measure is identical to NMI with arithmetic-mean normalization.
    nmi = v_measure_score(labels, pred)

    auc, ap = model.test(z, data.test_pos_edge_index, data.test_neg_edge_index)

    return auc, ap, completeness, hm, nmi

Initialize the model

Initialize an encoder and a discriminator

[ ]:
latent_size = 32

# Variational GCN encoder producing latent_size-dimensional node embeddings.
encoder = VEncoder(in_channels=data.num_features, out_channels=latent_size)

# MLP discriminator over latent vectors with a single output score.
discriminator = Discriminator(
    in_channels=latent_size,
    hidden_channels=64,
    out_channels=1,
)

Initialize the model and move it and the data to the selected device (the GPU only if `use_cuda` is True and CUDA is available; otherwise the CPU)

[ ]:
model = ARGVA(encoder, discriminator)

# Use the GPU only when it is both available and explicitly requested.
device_name = 'cuda' if torch.cuda.is_available() and use_cuda else 'cpu'
device = torch.device(device_name)
model = model.to(device)
data = data.to(device)

Define the optimizers

[ ]:
# Separate Adam optimizers so the discriminator and the encoder can be
# stepped independently, each with its own learning rate.
discriminator_optimizer = torch.optim.Adam(
    params=discriminator.parameters(), lr=0.001)
encoder_optimizer = torch.optim.Adam(
    params=encoder.parameters(), lr=0.005)

Train the model

[ ]:
# Train for 200 epochs, evaluating and logging after every epoch.
num_epochs = 200
for epoch in range(1, num_epochs + 1):
    loss = train()
    auc, ap, completeness, hm, nmi = test()
    print(
        f'Epoch: {epoch:03d}, Loss: {loss:.3f}, AUC: {auc:.3f}, '
        f'AP: {ap:.3f}, Completeness: {completeness:.3f}, '
        f'Homogeneity: {hm:.3f}, NMI: {nmi:.3f}'
    )
Epoch: 001, Loss: 5.541, AUC: 0.691, AP: 0.708, Completeness: 0.070, Homogeneity: 0.069, NMI: 0.069
Epoch: 002, Loss: 4.885, AUC: 0.677, AP: 0.702, Completeness: 0.051, Homogeneity: 0.049, NMI: 0.050
Epoch: 003, Loss: 4.419, AUC: 0.677, AP: 0.703, Completeness: 0.037, Homogeneity: 0.034, NMI: 0.036
Epoch: 004, Loss: 4.100, AUC: 0.679, AP: 0.706, Completeness: 0.040, Homogeneity: 0.037, NMI: 0.038
Epoch: 005, Loss: 3.796, AUC: 0.684, AP: 0.710, Completeness: 0.045, Homogeneity: 0.041, NMI: 0.043
Epoch: 006, Loss: 3.474, AUC: 0.693, AP: 0.718, Completeness: 0.084, Homogeneity: 0.079, NMI: 0.081
Epoch: 007, Loss: 3.168, AUC: 0.702, AP: 0.727, Completeness: 0.092, Homogeneity: 0.089, NMI: 0.091
Epoch: 008, Loss: 2.958, AUC: 0.712, AP: 0.735, Completeness: 0.098, Homogeneity: 0.095, NMI: 0.096
Epoch: 009, Loss: 2.735, AUC: 0.717, AP: 0.738, Completeness: 0.105, Homogeneity: 0.101, NMI: 0.103
Epoch: 010, Loss: 2.636, AUC: 0.717, AP: 0.737, Completeness: 0.110, Homogeneity: 0.107, NMI: 0.109
Epoch: 011, Loss: 2.548, AUC: 0.717, AP: 0.736, Completeness: 0.123, Homogeneity: 0.119, NMI: 0.121
Epoch: 012, Loss: 2.469, AUC: 0.720, AP: 0.736, Completeness: 0.138, Homogeneity: 0.138, NMI: 0.138
Epoch: 013, Loss: 2.399, AUC: 0.726, AP: 0.739, Completeness: 0.152, Homogeneity: 0.151, NMI: 0.151
Epoch: 014, Loss: 2.358, AUC: 0.724, AP: 0.739, Completeness: 0.163, Homogeneity: 0.158, NMI: 0.160
Epoch: 015, Loss: 2.334, AUC: 0.717, AP: 0.733, Completeness: 0.161, Homogeneity: 0.159, NMI: 0.160
Epoch: 016, Loss: 2.393, AUC: 0.712, AP: 0.730, Completeness: 0.141, Homogeneity: 0.136, NMI: 0.139
Epoch: 017, Loss: 2.529, AUC: 0.713, AP: 0.731, Completeness: 0.172, Homogeneity: 0.168, NMI: 0.170
Epoch: 018, Loss: 2.679, AUC: 0.722, AP: 0.739, Completeness: 0.198, Homogeneity: 0.196, NMI: 0.197
Epoch: 019, Loss: 2.736, AUC: 0.741, AP: 0.755, Completeness: 0.217, Homogeneity: 0.213, NMI: 0.215
Epoch: 020, Loss: 2.693, AUC: 0.762, AP: 0.774, Completeness: 0.235, Homogeneity: 0.228, NMI: 0.231
Epoch: 021, Loss: 2.660, AUC: 0.765, AP: 0.776, Completeness: 0.221, Homogeneity: 0.212, NMI: 0.216
Epoch: 022, Loss: 2.692, AUC: 0.754, AP: 0.768, Completeness: 0.223, Homogeneity: 0.204, NMI: 0.213
Epoch: 023, Loss: 2.868, AUC: 0.746, AP: 0.763, Completeness: 0.214, Homogeneity: 0.201, NMI: 0.207
Epoch: 024, Loss: 3.095, AUC: 0.745, AP: 0.764, Completeness: 0.196, Homogeneity: 0.192, NMI: 0.194
Epoch: 025, Loss: 3.238, AUC: 0.752, AP: 0.770, Completeness: 0.223, Homogeneity: 0.213, NMI: 0.218
Epoch: 026, Loss: 3.317, AUC: 0.764, AP: 0.781, Completeness: 0.272, Homogeneity: 0.263, NMI: 0.267
Epoch: 027, Loss: 3.291, AUC: 0.777, AP: 0.793, Completeness: 0.256, Homogeneity: 0.255, NMI: 0.256
Epoch: 028, Loss: 3.261, AUC: 0.788, AP: 0.802, Completeness: 0.281, Homogeneity: 0.269, NMI: 0.275
Epoch: 029, Loss: 3.257, AUC: 0.791, AP: 0.803, Completeness: 0.268, Homogeneity: 0.256, NMI: 0.262
Epoch: 030, Loss: 3.356, AUC: 0.788, AP: 0.798, Completeness: 0.264, Homogeneity: 0.257, NMI: 0.260
Epoch: 031, Loss: 3.461, AUC: 0.786, AP: 0.793, Completeness: 0.277, Homogeneity: 0.255, NMI: 0.265
Epoch: 032, Loss: 3.513, AUC: 0.789, AP: 0.791, Completeness: 0.260, Homogeneity: 0.240, NMI: 0.250
Epoch: 033, Loss: 3.582, AUC: 0.795, AP: 0.791, Completeness: 0.268, Homogeneity: 0.246, NMI: 0.257
Epoch: 034, Loss: 3.648, AUC: 0.798, AP: 0.789, Completeness: 0.268, Homogeneity: 0.246, NMI: 0.256
Epoch: 035, Loss: 3.760, AUC: 0.795, AP: 0.787, Completeness: 0.268, Homogeneity: 0.242, NMI: 0.254
Epoch: 036, Loss: 3.879, AUC: 0.789, AP: 0.785, Completeness: 0.287, Homogeneity: 0.258, NMI: 0.272
Epoch: 037, Loss: 3.949, AUC: 0.786, AP: 0.787, Completeness: 0.299, Homogeneity: 0.268, NMI: 0.282
Epoch: 038, Loss: 3.989, AUC: 0.787, AP: 0.790, Completeness: 0.305, Homogeneity: 0.272, NMI: 0.287
Epoch: 039, Loss: 4.134, AUC: 0.792, AP: 0.796, Completeness: 0.313, Homogeneity: 0.275, NMI: 0.293
Epoch: 040, Loss: 4.237, AUC: 0.798, AP: 0.803, Completeness: 0.315, Homogeneity: 0.286, NMI: 0.300
Epoch: 041, Loss: 4.137, AUC: 0.804, AP: 0.809, Completeness: 0.332, Homogeneity: 0.291, NMI: 0.310
Epoch: 042, Loss: 4.156, AUC: 0.808, AP: 0.812, Completeness: 0.341, Homogeneity: 0.303, NMI: 0.321
Epoch: 043, Loss: 4.106, AUC: 0.809, AP: 0.812, Completeness: 0.343, Homogeneity: 0.324, NMI: 0.333
Epoch: 044, Loss: 4.102, AUC: 0.808, AP: 0.810, Completeness: 0.351, Homogeneity: 0.333, NMI: 0.342
Epoch: 045, Loss: 4.084, AUC: 0.806, AP: 0.807, Completeness: 0.359, Homogeneity: 0.344, NMI: 0.351
Epoch: 046, Loss: 4.026, AUC: 0.805, AP: 0.805, Completeness: 0.362, Homogeneity: 0.347, NMI: 0.354
Epoch: 047, Loss: 4.030, AUC: 0.805, AP: 0.805, Completeness: 0.365, Homogeneity: 0.351, NMI: 0.358
Epoch: 048, Loss: 4.008, AUC: 0.808, AP: 0.807, Completeness: 0.372, Homogeneity: 0.360, NMI: 0.366
Epoch: 049, Loss: 3.979, AUC: 0.812, AP: 0.810, Completeness: 0.373, Homogeneity: 0.360, NMI: 0.367
Epoch: 050, Loss: 3.998, AUC: 0.817, AP: 0.814, Completeness: 0.375, Homogeneity: 0.362, NMI: 0.368
Epoch: 051, Loss: 4.072, AUC: 0.822, AP: 0.816, Completeness: 0.377, Homogeneity: 0.364, NMI: 0.370
Epoch: 052, Loss: 4.046, AUC: 0.825, AP: 0.818, Completeness: 0.379, Homogeneity: 0.364, NMI: 0.372
Epoch: 053, Loss: 4.086, AUC: 0.828, AP: 0.819, Completeness: 0.382, Homogeneity: 0.366, NMI: 0.374
Epoch: 054, Loss: 4.046, AUC: 0.829, AP: 0.819, Completeness: 0.390, Homogeneity: 0.374, NMI: 0.382
Epoch: 055, Loss: 4.131, AUC: 0.829, AP: 0.819, Completeness: 0.385, Homogeneity: 0.348, NMI: 0.366
Epoch: 056, Loss: 4.190, AUC: 0.828, AP: 0.817, Completeness: 0.391, Homogeneity: 0.351, NMI: 0.370
Epoch: 057, Loss: 4.164, AUC: 0.828, AP: 0.817, Completeness: 0.381, Homogeneity: 0.342, NMI: 0.360
Epoch: 058, Loss: 4.181, AUC: 0.827, AP: 0.815, Completeness: 0.350, Homogeneity: 0.343, NMI: 0.346
Epoch: 059, Loss: 4.249, AUC: 0.826, AP: 0.815, Completeness: 0.355, Homogeneity: 0.347, NMI: 0.351
Epoch: 060, Loss: 4.240, AUC: 0.824, AP: 0.813, Completeness: 0.381, Homogeneity: 0.345, NMI: 0.362
Epoch: 061, Loss: 4.244, AUC: 0.823, AP: 0.812, Completeness: 0.357, Homogeneity: 0.348, NMI: 0.352
Epoch: 062, Loss: 4.212, AUC: 0.822, AP: 0.812, Completeness: 0.355, Homogeneity: 0.347, NMI: 0.351
Epoch: 063, Loss: 4.203, AUC: 0.823, AP: 0.812, Completeness: 0.370, Homogeneity: 0.346, NMI: 0.358
Epoch: 064, Loss: 4.215, AUC: 0.823, AP: 0.813, Completeness: 0.372, Homogeneity: 0.352, NMI: 0.362
Epoch: 065, Loss: 4.099, AUC: 0.824, AP: 0.814, Completeness: 0.382, Homogeneity: 0.358, NMI: 0.370
Epoch: 066, Loss: 4.164, AUC: 0.825, AP: 0.815, Completeness: 0.380, Homogeneity: 0.360, NMI: 0.370
Epoch: 067, Loss: 4.110, AUC: 0.826, AP: 0.816, Completeness: 0.387, Homogeneity: 0.365, NMI: 0.376
Epoch: 068, Loss: 4.083, AUC: 0.827, AP: 0.817, Completeness: 0.392, Homogeneity: 0.370, NMI: 0.381
Epoch: 069, Loss: 4.102, AUC: 0.828, AP: 0.819, Completeness: 0.398, Homogeneity: 0.383, NMI: 0.390
Epoch: 070, Loss: 4.081, AUC: 0.830, AP: 0.821, Completeness: 0.404, Homogeneity: 0.390, NMI: 0.396
Epoch: 071, Loss: 4.066, AUC: 0.832, AP: 0.823, Completeness: 0.406, Homogeneity: 0.392, NMI: 0.399
Epoch: 072, Loss: 4.122, AUC: 0.834, AP: 0.826, Completeness: 0.406, Homogeneity: 0.391, NMI: 0.398
Epoch: 073, Loss: 4.030, AUC: 0.837, AP: 0.829, Completeness: 0.383, Homogeneity: 0.391, NMI: 0.387
Epoch: 074, Loss: 4.069, AUC: 0.840, AP: 0.832, Completeness: 0.415, Homogeneity: 0.400, NMI: 0.407
Epoch: 075, Loss: 4.204, AUC: 0.843, AP: 0.835, Completeness: 0.415, Homogeneity: 0.403, NMI: 0.409
Epoch: 076, Loss: 4.044, AUC: 0.846, AP: 0.837, Completeness: 0.429, Homogeneity: 0.408, NMI: 0.418
Epoch: 077, Loss: 4.168, AUC: 0.848, AP: 0.840, Completeness: 0.418, Homogeneity: 0.406, NMI: 0.412
Epoch: 078, Loss: 4.217, AUC: 0.851, AP: 0.842, Completeness: 0.418, Homogeneity: 0.414, NMI: 0.416
Epoch: 079, Loss: 4.152, AUC: 0.854, AP: 0.845, Completeness: 0.433, Homogeneity: 0.418, NMI: 0.426
Epoch: 080, Loss: 4.171, AUC: 0.857, AP: 0.848, Completeness: 0.434, Homogeneity: 0.419, NMI: 0.426
Epoch: 081, Loss: 4.152, AUC: 0.859, AP: 0.851, Completeness: 0.421, Homogeneity: 0.416, NMI: 0.418
Epoch: 082, Loss: 4.172, AUC: 0.862, AP: 0.853, Completeness: 0.425, Homogeneity: 0.425, NMI: 0.425
Epoch: 083, Loss: 4.107, AUC: 0.864, AP: 0.856, Completeness: 0.434, Homogeneity: 0.422, NMI: 0.428
Epoch: 084, Loss: 4.242, AUC: 0.866, AP: 0.859, Completeness: 0.436, Homogeneity: 0.424, NMI: 0.430
Epoch: 085, Loss: 4.062, AUC: 0.869, AP: 0.862, Completeness: 0.429, Homogeneity: 0.429, NMI: 0.429
Epoch: 086, Loss: 4.057, AUC: 0.871, AP: 0.864, Completeness: 0.432, Homogeneity: 0.424, NMI: 0.428
Epoch: 087, Loss: 4.028, AUC: 0.872, AP: 0.866, Completeness: 0.432, Homogeneity: 0.433, NMI: 0.432
Epoch: 088, Loss: 4.047, AUC: 0.874, AP: 0.868, Completeness: 0.436, Homogeneity: 0.429, NMI: 0.432
Epoch: 089, Loss: 4.004, AUC: 0.874, AP: 0.869, Completeness: 0.439, Homogeneity: 0.433, NMI: 0.436
Epoch: 090, Loss: 3.814, AUC: 0.875, AP: 0.870, Completeness: 0.441, Homogeneity: 0.435, NMI: 0.438
Epoch: 091, Loss: 3.919, AUC: 0.875, AP: 0.870, Completeness: 0.442, Homogeneity: 0.437, NMI: 0.440
Epoch: 092, Loss: 3.919, AUC: 0.876, AP: 0.871, Completeness: 0.442, Homogeneity: 0.437, NMI: 0.440
Epoch: 093, Loss: 3.943, AUC: 0.876, AP: 0.871, Completeness: 0.443, Homogeneity: 0.439, NMI: 0.441
Epoch: 094, Loss: 3.920, AUC: 0.875, AP: 0.871, Completeness: 0.444, Homogeneity: 0.439, NMI: 0.442
Epoch: 095, Loss: 3.912, AUC: 0.876, AP: 0.872, Completeness: 0.444, Homogeneity: 0.440, NMI: 0.442
Epoch: 096, Loss: 3.902, AUC: 0.877, AP: 0.874, Completeness: 0.446, Homogeneity: 0.441, NMI: 0.443
Epoch: 097, Loss: 3.894, AUC: 0.877, AP: 0.875, Completeness: 0.491, Homogeneity: 0.462, NMI: 0.476
Epoch: 098, Loss: 3.931, AUC: 0.879, AP: 0.876, Completeness: 0.492, Homogeneity: 0.463, NMI: 0.477
Epoch: 099, Loss: 3.890, AUC: 0.880, AP: 0.878, Completeness: 0.490, Homogeneity: 0.462, NMI: 0.476
Epoch: 100, Loss: 3.883, AUC: 0.880, AP: 0.879, Completeness: 0.492, Homogeneity: 0.464, NMI: 0.478
Epoch: 101, Loss: 3.873, AUC: 0.881, AP: 0.881, Completeness: 0.489, Homogeneity: 0.462, NMI: 0.475
Epoch: 102, Loss: 3.824, AUC: 0.881, AP: 0.882, Completeness: 0.490, Homogeneity: 0.465, NMI: 0.477
Epoch: 103, Loss: 3.846, AUC: 0.882, AP: 0.882, Completeness: 0.491, Homogeneity: 0.466, NMI: 0.478
Epoch: 104, Loss: 3.797, AUC: 0.882, AP: 0.883, Completeness: 0.492, Homogeneity: 0.466, NMI: 0.479
Epoch: 105, Loss: 3.863, AUC: 0.881, AP: 0.882, Completeness: 0.436, Homogeneity: 0.438, NMI: 0.437
Epoch: 106, Loss: 3.805, AUC: 0.881, AP: 0.882, Completeness: 0.493, Homogeneity: 0.469, NMI: 0.481
Epoch: 107, Loss: 3.804, AUC: 0.881, AP: 0.882, Completeness: 0.481, Homogeneity: 0.459, NMI: 0.470
Epoch: 108, Loss: 3.855, AUC: 0.881, AP: 0.883, Completeness: 0.497, Homogeneity: 0.472, NMI: 0.484
Epoch: 109, Loss: 3.757, AUC: 0.881, AP: 0.883, Completeness: 0.480, Homogeneity: 0.458, NMI: 0.468
Epoch: 110, Loss: 3.787, AUC: 0.882, AP: 0.884, Completeness: 0.497, Homogeneity: 0.472, NMI: 0.484
Epoch: 111, Loss: 3.785, AUC: 0.883, AP: 0.885, Completeness: 0.478, Homogeneity: 0.457, NMI: 0.467
Epoch: 112, Loss: 3.768, AUC: 0.883, AP: 0.886, Completeness: 0.479, Homogeneity: 0.458, NMI: 0.468
Epoch: 113, Loss: 3.861, AUC: 0.884, AP: 0.887, Completeness: 0.470, Homogeneity: 0.450, NMI: 0.460
Epoch: 114, Loss: 3.844, AUC: 0.883, AP: 0.887, Completeness: 0.440, Homogeneity: 0.433, NMI: 0.437
Epoch: 115, Loss: 3.839, AUC: 0.884, AP: 0.887, Completeness: 0.461, Homogeneity: 0.442, NMI: 0.452
Epoch: 116, Loss: 3.826, AUC: 0.883, AP: 0.887, Completeness: 0.430, Homogeneity: 0.424, NMI: 0.427
Epoch: 117, Loss: 3.865, AUC: 0.883, AP: 0.887, Completeness: 0.453, Homogeneity: 0.437, NMI: 0.445
Epoch: 118, Loss: 3.852, AUC: 0.883, AP: 0.887, Completeness: 0.466, Homogeneity: 0.447, NMI: 0.456
Epoch: 119, Loss: 3.867, AUC: 0.884, AP: 0.887, Completeness: 0.442, Homogeneity: 0.426, NMI: 0.434
Epoch: 120, Loss: 3.817, AUC: 0.884, AP: 0.888, Completeness: 0.463, Homogeneity: 0.444, NMI: 0.453
Epoch: 121, Loss: 3.834, AUC: 0.884, AP: 0.888, Completeness: 0.441, Homogeneity: 0.426, NMI: 0.434
Epoch: 122, Loss: 3.874, AUC: 0.885, AP: 0.889, Completeness: 0.463, Homogeneity: 0.445, NMI: 0.454
Epoch: 123, Loss: 3.886, AUC: 0.886, AP: 0.890, Completeness: 0.465, Homogeneity: 0.448, NMI: 0.456
Epoch: 124, Loss: 3.813, AUC: 0.887, AP: 0.891, Completeness: 0.462, Homogeneity: 0.445, NMI: 0.454
Epoch: 125, Loss: 3.849, AUC: 0.888, AP: 0.892, Completeness: 0.421, Homogeneity: 0.418, NMI: 0.420
Epoch: 126, Loss: 3.838, AUC: 0.888, AP: 0.893, Completeness: 0.424, Homogeneity: 0.422, NMI: 0.423
Epoch: 127, Loss: 3.773, AUC: 0.889, AP: 0.893, Completeness: 0.422, Homogeneity: 0.420, NMI: 0.421
Epoch: 128, Loss: 3.822, AUC: 0.889, AP: 0.894, Completeness: 0.420, Homogeneity: 0.418, NMI: 0.419
Epoch: 129, Loss: 3.843, AUC: 0.890, AP: 0.894, Completeness: 0.432, Homogeneity: 0.430, NMI: 0.431
Epoch: 130, Loss: 3.772, AUC: 0.890, AP: 0.895, Completeness: 0.432, Homogeneity: 0.430, NMI: 0.431
Epoch: 131, Loss: 3.748, AUC: 0.891, AP: 0.895, Completeness: 0.419, Homogeneity: 0.418, NMI: 0.418
Epoch: 132, Loss: 3.762, AUC: 0.891, AP: 0.896, Completeness: 0.428, Homogeneity: 0.427, NMI: 0.428
Epoch: 133, Loss: 3.855, AUC: 0.892, AP: 0.897, Completeness: 0.431, Homogeneity: 0.430, NMI: 0.431
Epoch: 134, Loss: 3.768, AUC: 0.893, AP: 0.898, Completeness: 0.434, Homogeneity: 0.434, NMI: 0.434
Epoch: 135, Loss: 3.712, AUC: 0.893, AP: 0.898, Completeness: 0.433, Homogeneity: 0.432, NMI: 0.432
Epoch: 136, Loss: 3.742, AUC: 0.893, AP: 0.898, Completeness: 0.434, Homogeneity: 0.434, NMI: 0.434
Epoch: 137, Loss: 3.752, AUC: 0.893, AP: 0.898, Completeness: 0.431, Homogeneity: 0.431, NMI: 0.431
Epoch: 138, Loss: 3.747, AUC: 0.893, AP: 0.898, Completeness: 0.428, Homogeneity: 0.430, NMI: 0.429
Epoch: 139, Loss: 3.709, AUC: 0.893, AP: 0.898, Completeness: 0.413, Homogeneity: 0.418, NMI: 0.415
Epoch: 140, Loss: 3.766, AUC: 0.892, AP: 0.897, Completeness: 0.425, Homogeneity: 0.427, NMI: 0.426
Epoch: 141, Loss: 3.757, AUC: 0.892, AP: 0.896, Completeness: 0.405, Homogeneity: 0.412, NMI: 0.408
Epoch: 142, Loss: 3.652, AUC: 0.892, AP: 0.896, Completeness: 0.429, Homogeneity: 0.434, NMI: 0.431
Epoch: 143, Loss: 3.719, AUC: 0.892, AP: 0.897, Completeness: 0.403, Homogeneity: 0.411, NMI: 0.407
Epoch: 144, Loss: 3.801, AUC: 0.894, AP: 0.899, Completeness: 0.437, Homogeneity: 0.439, NMI: 0.438
Epoch: 145, Loss: 3.719, AUC: 0.894, AP: 0.900, Completeness: 0.401, Homogeneity: 0.409, NMI: 0.405
Epoch: 146, Loss: 3.731, AUC: 0.895, AP: 0.901, Completeness: 0.399, Homogeneity: 0.407, NMI: 0.403
Epoch: 147, Loss: 3.626, AUC: 0.895, AP: 0.902, Completeness: 0.408, Homogeneity: 0.414, NMI: 0.411
Epoch: 148, Loss: 3.660, AUC: 0.895, AP: 0.902, Completeness: 0.382, Homogeneity: 0.387, NMI: 0.385
Epoch: 149, Loss: 3.626, AUC: 0.895, AP: 0.903, Completeness: 0.396, Homogeneity: 0.409, NMI: 0.402
Epoch: 150, Loss: 3.638, AUC: 0.895, AP: 0.902, Completeness: 0.351, Homogeneity: 0.357, NMI: 0.354
Epoch: 151, Loss: 3.667, AUC: 0.895, AP: 0.902, Completeness: 0.391, Homogeneity: 0.406, NMI: 0.399
Epoch: 152, Loss: 3.661, AUC: 0.895, AP: 0.902, Completeness: 0.371, Homogeneity: 0.387, NMI: 0.379
Epoch: 153, Loss: 3.656, AUC: 0.895, AP: 0.901, Completeness: 0.404, Homogeneity: 0.415, NMI: 0.409
Epoch: 154, Loss: 3.664, AUC: 0.895, AP: 0.901, Completeness: 0.390, Homogeneity: 0.407, NMI: 0.398
Epoch: 155, Loss: 3.614, AUC: 0.896, AP: 0.902, Completeness: 0.391, Homogeneity: 0.406, NMI: 0.398
Epoch: 156, Loss: 3.656, AUC: 0.897, AP: 0.903, Completeness: 0.402, Homogeneity: 0.417, NMI: 0.410
Epoch: 157, Loss: 3.554, AUC: 0.898, AP: 0.905, Completeness: 0.400, Homogeneity: 0.414, NMI: 0.407
Epoch: 158, Loss: 3.628, AUC: 0.899, AP: 0.906, Completeness: 0.404, Homogeneity: 0.417, NMI: 0.411
Epoch: 159, Loss: 3.538, AUC: 0.899, AP: 0.907, Completeness: 0.404, Homogeneity: 0.421, NMI: 0.412
Epoch: 160, Loss: 3.566, AUC: 0.899, AP: 0.907, Completeness: 0.413, Homogeneity: 0.424, NMI: 0.419
Epoch: 161, Loss: 3.491, AUC: 0.899, AP: 0.907, Completeness: 0.399, Homogeneity: 0.415, NMI: 0.407
Epoch: 162, Loss: 3.574, AUC: 0.898, AP: 0.907, Completeness: 0.370, Homogeneity: 0.380, NMI: 0.375
Epoch: 163, Loss: 3.555, AUC: 0.898, AP: 0.907, Completeness: 0.403, Homogeneity: 0.418, NMI: 0.410
Epoch: 164, Loss: 3.581, AUC: 0.898, AP: 0.907, Completeness: 0.387, Homogeneity: 0.393, NMI: 0.390
Epoch: 165, Loss: 3.543, AUC: 0.898, AP: 0.907, Completeness: 0.455, Homogeneity: 0.472, NMI: 0.463
Epoch: 166, Loss: 3.576, AUC: 0.899, AP: 0.907, Completeness: 0.445, Homogeneity: 0.448, NMI: 0.446
Epoch: 167, Loss: 3.483, AUC: 0.900, AP: 0.908, Completeness: 0.404, Homogeneity: 0.418, NMI: 0.411
Epoch: 168, Loss: 3.598, AUC: 0.901, AP: 0.908, Completeness: 0.403, Homogeneity: 0.424, NMI: 0.413
Epoch: 169, Loss: 3.535, AUC: 0.901, AP: 0.909, Completeness: 0.450, Homogeneity: 0.456, NMI: 0.453
Epoch: 170, Loss: 3.513, AUC: 0.901, AP: 0.910, Completeness: 0.437, Homogeneity: 0.428, NMI: 0.432
Epoch: 171, Loss: 3.537, AUC: 0.902, AP: 0.910, Completeness: 0.367, Homogeneity: 0.376, NMI: 0.372
Epoch: 172, Loss: 3.514, AUC: 0.902, AP: 0.910, Completeness: 0.400, Homogeneity: 0.420, NMI: 0.410
Epoch: 173, Loss: 3.512, AUC: 0.902, AP: 0.910, Completeness: 0.365, Homogeneity: 0.375, NMI: 0.370
Epoch: 174, Loss: 3.500, AUC: 0.902, AP: 0.910, Completeness: 0.441, Homogeneity: 0.435, NMI: 0.438
Epoch: 175, Loss: 3.518, AUC: 0.903, AP: 0.910, Completeness: 0.394, Homogeneity: 0.390, NMI: 0.392
Epoch: 176, Loss: 3.515, AUC: 0.904, AP: 0.911, Completeness: 0.375, Homogeneity: 0.383, NMI: 0.379
Epoch: 177, Loss: 3.495, AUC: 0.905, AP: 0.911, Completeness: 0.387, Homogeneity: 0.395, NMI: 0.391
Epoch: 178, Loss: 3.581, AUC: 0.904, AP: 0.911, Completeness: 0.388, Homogeneity: 0.396, NMI: 0.392
Epoch: 179, Loss: 3.491, AUC: 0.903, AP: 0.910, Completeness: 0.388, Homogeneity: 0.396, NMI: 0.392
Epoch: 180, Loss: 3.589, AUC: 0.902, AP: 0.909, Completeness: 0.381, Homogeneity: 0.387, NMI: 0.384
Epoch: 181, Loss: 3.486, AUC: 0.900, AP: 0.907, Completeness: 0.385, Homogeneity: 0.394, NMI: 0.389
Epoch: 182, Loss: 3.592, AUC: 0.899, AP: 0.905, Completeness: 0.379, Homogeneity: 0.393, NMI: 0.386
Epoch: 183, Loss: 3.544, AUC: 0.897, AP: 0.903, Completeness: 0.406, Homogeneity: 0.423, NMI: 0.414
Epoch: 184, Loss: 3.511, AUC: 0.896, AP: 0.902, Completeness: 0.371, Homogeneity: 0.383, NMI: 0.377
Epoch: 185, Loss: 3.443, AUC: 0.896, AP: 0.901, Completeness: 0.390, Homogeneity: 0.391, NMI: 0.390
Epoch: 186, Loss: 3.470, AUC: 0.896, AP: 0.902, Completeness: 0.349, Homogeneity: 0.359, NMI: 0.354
Epoch: 187, Loss: 3.488, AUC: 0.896, AP: 0.902, Completeness: 0.386, Homogeneity: 0.387, NMI: 0.387
Epoch: 188, Loss: 3.486, AUC: 0.896, AP: 0.902, Completeness: 0.348, Homogeneity: 0.358, NMI: 0.353
Epoch: 189, Loss: 3.424, AUC: 0.896, AP: 0.902, Completeness: 0.361, Homogeneity: 0.362, NMI: 0.361
Epoch: 190, Loss: 3.471, AUC: 0.896, AP: 0.901, Completeness: 0.370, Homogeneity: 0.373, NMI: 0.371
Epoch: 191, Loss: 3.397, AUC: 0.895, AP: 0.900, Completeness: 0.341, Homogeneity: 0.353, NMI: 0.347
Epoch: 192, Loss: 3.355, AUC: 0.895, AP: 0.901, Completeness: 0.385, Homogeneity: 0.389, NMI: 0.387
Epoch: 193, Loss: 3.484, AUC: 0.896, AP: 0.902, Completeness: 0.337, Homogeneity: 0.348, NMI: 0.342
Epoch: 194, Loss: 3.314, AUC: 0.896, AP: 0.903, Completeness: 0.354, Homogeneity: 0.364, NMI: 0.359
Epoch: 195, Loss: 3.459, AUC: 0.895, AP: 0.904, Completeness: 0.338, Homogeneity: 0.349, NMI: 0.343
Epoch: 196, Loss: 3.409, AUC: 0.895, AP: 0.904, Completeness: 0.355, Homogeneity: 0.374, NMI: 0.364
Epoch: 197, Loss: 3.446, AUC: 0.894, AP: 0.904, Completeness: 0.361, Homogeneity: 0.381, NMI: 0.371
Epoch: 198, Loss: 3.322, AUC: 0.893, AP: 0.903, Completeness: 0.292, Homogeneity: 0.300, NMI: 0.296
Epoch: 199, Loss: 3.569, AUC: 0.893, AP: 0.904, Completeness: 0.290, Homogeneity: 0.300, NMI: 0.295
Epoch: 200, Loss: 3.459, AUC: 0.893, AP: 0.904, Completeness: 0.326, Homogeneity: 0.334, NMI: 0.330
[ ]:
@torch.no_grad()
def plot_points(colors):
    """Plot a 2-D t-SNE projection of the node embeddings, coloured by class.

    Args:
        colors: Non-empty sequence of matplotlib colours. Class ``i`` is drawn
            with ``colors[i % len(colors)]``, so a palette shorter than the
            number of classes is reused instead of raising ``IndexError``.

    Raises:
        ValueError: If ``colors`` is empty.
    """
    if len(colors) == 0:
        raise ValueError('colors must contain at least one colour')

    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    # Set init/learning_rate explicitly: scikit-learn 1.2 switched the
    # defaults to these values, and older releases emit a FutureWarning when
    # they are left unset.
    tsne = TSNE(n_components=2, init='pca', learning_rate='auto')
    z = tsne.fit_transform(z.cpu().numpy())
    y = data.y.cpu().numpy()

    fig = plt.figure(1, figsize=(8, 8))
    fig.clf()  # reuse figure 1, clearing any previous plot
    for i in range(dataset.num_classes):
        in_class = y == i
        plt.scatter(z[in_class, 0], z[in_class, 1], s=20,
                    color=colors[i % len(colors)])
    plt.axis('off')
    plt.show()
[ ]:
#%%
# One colour per class, indexed by label id.
colors = [
    '#ffc0cb',  # pink
    '#bada55',  # yellow-green
    '#008080',  # teal
    '#420420',  # dark maroon
    '#7fe5f0',  # light cyan
    '#065535',  # dark green
    '#ffd700',  # gold
]
plot_points(colors)
/usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:783: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.
  FutureWarning,
/usr/local/lib/python3.7/dist-packages/sklearn/manifold/_t_sne.py:793: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.
  FutureWarning,
../../_images/ipynbs_pyg_tutorial_project_Tutorial07_31_1.png