Tutorial 12: Graph Autoencoder (GAE) for link prediction
[1]:
import os
import torch
# Publish the installed torch version as the TORCH environment variable so the
# (commented-out) pip commands below can pick the matching wheel index via ${TORCH}.
os.environ.update(TORCH=torch.__version__)
print(os.environ['TORCH'])
# Uncomment to install the PyG stack (e.g. on Colab):
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.11.0
[2]:
import os.path as osp
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges
[3]:
# Run configuration.
# FORCE_CPU keeps the tutorial on the CPU even when CUDA is present. The original
# cell detected CUDA and then unconditionally overwrote the result with "cpu", which
# made the detection dead code; the switch makes that choice explicit.
# Set FORCE_CPU = False to train on the GPU when one is available.
FORCE_CPU = True
# Seed torch's global RNG.
# NOTE(review): other libraries may draw from a different RNG (e.g. Python's
# `random`); seed those as well if exact run-to-run reproducibility is required.
SEED = 42
torch.manual_seed(SEED)
device = torch.device(
    'cuda' if (torch.cuda.is_available() and not FORCE_CPU) else 'cpu')
[4]:
# Load the Cora dataset. Features are passed through T.NormalizeFeatures() every time
# an item is fetched from the dataset.
# `dataset_name` keeps the string separate from `dataset`, which is the Planetoid
# object used later (e.g. dataset.num_features when building the model).
dataset_name = 'Cora'
path = osp.join('.', 'data', dataset_name)
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]
# Print the graph object we will actually work with rather than reaching into the
# dataset's internal `.data` storage.
print(data)
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
Processing...
Done!
[5]:
# Use train_test_split_edges to create the positive/negative edge sets for
# train / val / test.
# Start from a fresh copy of the graph so this cell can be re-run safely:
# the split consumes the `data` object passed to it, so a second run on the
# already-split object would not see the original graph.
data = dataset[0]
# Drop the node-classification attributes; link prediction does not use them.
data.train_mask = data.val_mask = data.test_mask = data.y = None
# NOTE: train_test_split_edges is deprecated in favour of transforms.RandomLinkSplit
# (see the warning this call emits). It is kept here because the rest of the notebook
# relies on the attribute names it produces (train_pos_edge_index, val_neg_edge_index, ...).
data = train_test_split_edges(data)
print(data)
Data(x=[2708, 1433], val_pos_edge_index=[2, 263], test_pos_edge_index=[2, 527], train_pos_edge_index=[2, 8976], train_neg_adj_mask=[2708, 2708], val_neg_edge_index=[2, 263], test_neg_edge_index=[2, 527])
/home/user/anaconda3/envs/gnn/lib/python3.7/site-packages/torch_geometric/deprecation.py:12: UserWarning: 'train_test_split_edges' is deprecated, use 'transforms.RandomLinkSplit' instead
warnings.warn(out)
Simple autoencoder model
[6]:
class Net(torch.nn.Module):
    """Graph autoencoder: a two-layer GCN encoder with a dot-product decoder.

    Args:
        in_channels: Size of the input node features. Defaults to
            ``dataset.num_features`` of the notebook-level ``dataset``.
        hidden_channels: Output size of the first GCN layer (default 128).
        out_channels: Size of the node embeddings produced by the encoder
            (default 64).
    """

    def __init__(self, in_channels=None, hidden_channels=128, out_channels=64):
        super().__init__()
        if in_channels is None:
            # Backward-compatible default: read the feature size from the global dataset.
            in_channels = dataset.num_features
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x=None, edge_index=None):
        """Compute node embeddings.

        Args:
            x: Node feature matrix. Defaults to ``data.x`` of the notebook-level ``data``.
            edge_index: Edges used for message passing. Defaults to
                ``data.train_pos_edge_index`` so that only training edges are seen.

        Returns:
            The output of the second convolution applied to the ReLU of the first.
        """
        if x is None:
            x = data.x
        if edge_index is None:
            edge_index = data.train_pos_edge_index
        x = self.conv1(x, edge_index)  # convolution 1
        x = x.relu()
        return self.conv2(x, edge_index)  # convolution 2

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score the given candidate edges.

        Args:
            z: Node embeddings, indexed by node id along the first dimension.
            pos_edge_index: Positive edges, one column per edge.
            neg_edge_index: Negative edges, one column per edge.

        Returns:
            One raw logit (no sigmoid applied) per edge, positives first and then
            negatives: the dot product of the two endpoint embeddings.
        """
        # Concatenate positive and negative edges along the edge dimension.
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # dot product
        return logits

    def decode_all(self, z):
        """Predict an edge list over all node pairs.

        Returns:
            A tensor with one column per predicted edge, containing every (i, j)
            whose dot-product logit is strictly positive (i.e. sigmoid > 0.5).
        """
        adj_logits = z @ z.t()  # N x N matrix of pairwise logits
        return (adj_logits > 0).nonzero(as_tuple=False).t()  # predicted edge list
[7]:
# Instantiate the autoencoder and place both the model and the graph on `device`.
model = Net().to(device)
data = data.to(device)
# Adam over all model parameters with a learning rate of 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
[8]:
def get_link_labels(pos_edge_index, neg_edge_index):
    """Build the binary target vector for a set of positive and negative edges.

    Args:
        pos_edge_index: Positive edges, one column per edge.
        neg_edge_index: Negative edges, one column per edge.

    Returns:
        A float tensor ``[1, 1, ..., 0, 0, ...]`` where the number of ones equals the
        number of positive edges and the number of zeros equals the number of
        negative edges. It is allocated on the same device as ``pos_edge_index`` so
        it always matches the logits it is compared against.
    """
    num_pos = pos_edge_index.size(1)
    num_neg = neg_edge_index.size(1)
    link_labels = torch.zeros(num_pos + num_neg, dtype=torch.float,
                              device=pos_edge_index.device)
    link_labels[:num_pos] = 1.
    return link_labels
def train():
    """Run one optimisation step on the training edges and return the loss.

    Draws as many negative edges as there are positive training edges, scores both
    sets with the model, and minimises binary cross-entropy on the raw logits.
    """
    model.train()

    pos_edges = data.train_pos_edge_index
    # One negative sample per positive training edge.
    neg_edges = negative_sampling(
        edge_index=pos_edges,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edges.size(1),
    )

    optimizer.zero_grad()
    embeddings = model.encode()
    logits = model.decode(embeddings, pos_edges, neg_edges)
    targets = get_link_labels(pos_edges, neg_edges)
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    loss.backward()
    optimizer.step()
    return loss
@torch.no_grad()
def test():
    """Evaluate link prediction on the validation and test edges.

    Returns:
        A two-element list ``[val_auc, test_auc]`` of ROC-AUC scores.
    """
    model.eval()
    # The embeddings depend only on the model and the training graph, not on the
    # split being scored, so encode once instead of once per split.
    z = model.encode()
    perfs = []
    for prefix in ["val", "test"]:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)  # score val/test edges
        link_probs = link_logits.sigmoid()  # logits -> probabilities
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)  # 1 = pos, 0 = neg
        perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))  # ROC-AUC
    return perfs
[9]:
# Train for 100 epochs. Model selection is done on validation AUC: the reported test
# AUC is the one measured at the epoch with the best validation AUC so far.
log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'
best_val_perf = 0
test_perf = 0
for epoch in range(1, 101):
    train_loss = train()
    val_perf, tmp_test_perf = test()
    if val_perf > best_val_perf:
        best_val_perf, test_perf = val_perf, tmp_test_perf
    # Report every tenth epoch only.
    if epoch % 10 != 0:
        continue
    print(log.format(epoch, train_loss, best_val_perf, test_perf))
Epoch: 010, Loss: 0.6749, Val: 0.7122, Test: 0.7107
Epoch: 020, Loss: 0.6134, Val: 0.7764, Test: 0.7935
Epoch: 030, Loss: 0.5288, Val: 0.8475, Test: 0.8281
Epoch: 040, Loss: 0.4940, Val: 0.9000, Test: 0.8686
Epoch: 050, Loss: 0.4691, Val: 0.9251, Test: 0.9012
Epoch: 060, Loss: 0.4580, Val: 0.9300, Test: 0.9081
Epoch: 070, Loss: 0.4496, Val: 0.9300, Test: 0.9081
Epoch: 080, Loss: 0.4438, Val: 0.9317, Test: 0.9187
Epoch: 090, Loss: 0.4413, Val: 0.9365, Test: 0.9237
Epoch: 100, Loss: 0.4332, Val: 0.9410, Test: 0.9258
[10]:
# Final inference: embed all nodes and decode the full predicted edge list.
# This is evaluation only, so switch to eval mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    z = model.encode()
    final_edge_index = model.decode_all(z)