Tutorial 5: Aggregation
In this tutorial we override the aggregation method of PyTorch Geometric's GIN convolution module (`GINConv`), implementing the following aggregation methods:
Principal Neighborhood Aggregation (PNA)
Learning Aggregation Functions (LAF)
[ ]:
# Install the PyG extension wheels that match the installed torch build.
import os
import torch
# Export the torch version string (e.g. "1.12.1+cu113") so the shell commands
# below can expand ${TORCH} into the matching wheel-index URL.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# Install PyTorch Geometric itself straight from its GitHub repository.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
1.12.1+cu113
|████████████████████████████████| 7.9 MB 5.0 MB/s
|████████████████████████████████| 3.5 MB 4.9 MB/s
Building wheel for torch-geometric (setup.py) ... done
[ ]:
import torch
# Seed PyTorch's global RNG so the random weight initialisation below is reproducible.
torch.manual_seed(42)
<torch._C.Generator at 0x7f1e2cb8ca30>
Message Passing Class
[ ]:
from torch_geometric.nn import MessagePassing
[ ]:
# List MessagePassing's attributes to find the hooks a subclass can override (e.g. `aggregate`).
dir(MessagePassing)
['T_destination',
'__annotations__',
'__call__',
'__check_input__',
'__class__',
'__collect__',
'__delattr__',
'__dict__',
'__dir__',
'__doc__',
'__eq__',
'__format__',
'__ge__',
'__getattr__',
'__getattribute__',
'__gt__',
'__hash__',
'__init__',
'__init_subclass__',
'__le__',
'__lift__',
'__lt__',
'__module__',
'__ne__',
'__new__',
'__reduce__',
'__reduce_ex__',
'__repr__',
'__set_size__',
'__setattr__',
'__setstate__',
'__sizeof__',
'__str__',
'__subclasshook__',
'__weakref__',
'_apply',
'_call_impl',
'_get_backward_hooks',
'_get_name',
'_load_from_state_dict',
'_maybe_warn_non_full_backward_hook',
'_named_members',
'_register_load_state_dict_pre_hook',
'_register_state_dict_hook',
'_replicate_for_data_parallel',
'_save_to_state_dict',
'_slow_forward',
'_version',
'add_module',
'aggregate',
'apply',
'bfloat16',
'buffers',
'children',
'cpu',
'cuda',
'double',
'dump_patches',
'edge_update',
'edge_updater',
'eval',
'explain',
'explain_message',
'extra_repr',
'float',
'forward',
'get_buffer',
'get_extra_state',
'get_parameter',
'get_submodule',
'half',
'ipu',
'jittable',
'load_state_dict',
'message',
'message_and_aggregate',
'modules',
'named_buffers',
'named_children',
'named_modules',
'named_parameters',
'parameters',
'propagate',
'register_aggregate_forward_hook',
'register_aggregate_forward_pre_hook',
'register_backward_hook',
'register_buffer',
'register_edge_update_forward_hook',
'register_edge_update_forward_pre_hook',
'register_forward_hook',
'register_forward_pre_hook',
'register_full_backward_hook',
'register_load_state_dict_post_hook',
'register_message_and_aggregate_forward_hook',
'register_message_and_aggregate_forward_pre_hook',
'register_message_forward_hook',
'register_message_forward_pre_hook',
'register_module',
'register_parameter',
'register_propagate_forward_hook',
'register_propagate_forward_pre_hook',
'requires_grad_',
'set_extra_state',
'share_memory',
'special_args',
'state_dict',
'to',
'to_empty',
'train',
'type',
'update',
'xpu',
'zero_grad']
We are interested in the `aggregate` method or, if you are using a sparse adjacency matrix, in the `message_and_aggregate` method. Convolutional classes in PyG extend `MessagePassing`; we build our custom convolutional classes by extending `GINConv`.
A scatter-based aggregation layer (LAF), to be called from `aggregate`:
[ ]:
from torch.nn import Parameter, Module, Sigmoid
import torch
import torch_scatter
import torch.nn.functional as F
class AbstractLAFLayer(Module):
    """Hold the learnable weights and fixed selector tensors of an LAF layer.

    Keyword Args:
        units: number of LAF functions to learn (required unless ``weights``
            is given).
        weights: optional initial tensor of shape ``(12, units)``; when given
            it is used as the weight matrix and ``units`` is read from its
            second dimension.
        device: device for every tensor created here; defaults to CUDA when
            available, otherwise CPU.
        kernel_initializer: ``torch.nn.init`` scheme used for fresh weights,
            one of ``random_normal`` (default), ``glorot_normal``,
            ``he_normal``, ``random_uniform``, ``glorot_uniform``,
            ``he_uniform``.

    Raises:
        AssertionError: if neither ``units`` nor ``weights`` is passed, or if
            ``kernel_initializer`` is not one of the names above.
    """

    def __init__(self, **kwargs):
        super(AbstractLAFLayer, self).__init__()
        assert 'units' in kwargs or 'weights' in kwargs

        self.device = (kwargs['device'] if 'device' in kwargs
                       else torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.ngpus = torch.cuda.device_count()

        # Name -> in-place initializer for a freshly allocated weight matrix.
        init_fns = {
            'random_normal': torch.nn.init.normal_,
            'glorot_normal': torch.nn.init.xavier_normal_,
            'he_normal': torch.nn.init.kaiming_normal_,
            'random_uniform': torch.nn.init.uniform_,
            'glorot_uniform': torch.nn.init.xavier_uniform_,
            'he_uniform': torch.nn.init.kaiming_uniform_,
        }
        initializer = kwargs.get('kernel_initializer', 'random_normal')
        assert initializer in tuple(init_fns)
        self.kernel_initializer = initializer

        if 'weights' in kwargs:
            self.weights = Parameter(kwargs['weights'].to(self.device),
                                     requires_grad=True)
            self.units = self.weights.shape[1]
        else:
            self.units = kwargs['units']
            fresh = torch.empty(12, self.units, device=self.device)
            init_fns[initializer](fresh)
            self.weights = Parameter(fresh, requires_grad=True)

        # Non-trainable selectors: `e` maps an input x to x (+1) or 1 - x (-1);
        # num_idx / den_idx pick terms 0-1 for the numerator and terms 2-3 for
        # the denominator of the LAF ratio.
        const = {'dtype': torch.float32, 'device': self.device}
        self.e = Parameter(torch.tensor([1, -1, 1, -1], **const),
                           requires_grad=False)
        self.num_idx = Parameter(
            torch.tensor([1, 1, 0, 0], **const).view(1, 1, -1, 1),
            requires_grad=False)
        self.den_idx = Parameter(
            torch.tensor([0, 0, 1, 1], **const).view(1, 1, -1, 1),
            requires_grad=False)
class LAFLayer(AbstractLAFLayer):
    """Learnable aggregation function (LAF) applied with a scatter over ``index``.

    For every target ``i``, feature ``f`` and unit ``u`` it computes::

        (a * S_0 + b * S_1) / (c * S_2 + d * S_3)
        S_k = (sum_{j : index[j] == i} t_k(x_j) ** p_k) ** q_k

    where ``t_k(x) = x`` for k in {0, 2} and ``1 - x`` for k in {1, 3},
    ``p = relu(weights[0:4])``, ``q = relu(weights[4:8])`` and
    ``(a, b, c, d) = weights[8:12]``.

    Args:
        eps: clamping constant that keeps inputs inside ``[eps, 1 - eps]`` and
            the denominator away from zero.
        **kwargs: forwarded to ``AbstractLAFLayer`` (``units``, ``weights``,
            ``device``, ``kernel_initializer``).
    """

    def __init__(self, eps=1e-7, **kwargs):
        super(LAFLayer, self).__init__(**kwargs)
        self.eps = eps

    def forward(self, data, index, dim=0, dim_size=None, **kwargs):
        """Aggregate the rows of ``data`` that share the same ``index`` value.

        Args:
            data: 2-D tensor ``(E, F)`` to aggregate. Values are clamped into
                ``[eps, 1 - eps]``, so inputs are expected in ``(0, 1)``
                (e.g. sigmoid outputs).
            index: target row for each of the ``E`` rows of ``data``.
            dim: scatter dimension (the reshapes below assume ``0``).
            dim_size: number of output rows. Pass the number of target nodes
                so nodes receiving no messages still get a row; when ``None``
                the size is inferred as ``index.max() + 1``.
            **kwargs: ignored; accepted for call-site compatibility.

        Returns:
            Tensor of shape ``(dim_size, F, units)``.
        """
        eps = self.eps
        x = torch.clamp(data, eps, 1.0 - eps)
        # (E, F) -> (E, F, 4): e = +1 keeps x, e = -1 turns it into 1 - x.
        e = self.e.view(1, 1, -1)
        terms = (1. - e) / 2. + torch.unsqueeze(x, -1) * e
        # (E, F, 4, 1) ** (4, units) -> (E, F, 4, units)
        powered = torch.pow(torch.unsqueeze(terms, -1), torch.relu(self.weights[0:4]))
        summed = torch_scatter.scatter_add(powered, index.view(-1), dim=dim,
                                           dim_size=dim_size)
        # Rows with no incoming messages sum to exactly 0; keep the base of the
        # outer power strictly positive.
        summed = torch.clamp(summed, eps)
        outer = torch.pow(summed, torch.relu(self.weights[4:8]))
        weighted = outer * self.weights[8:12].view(1, 1, 4, -1)
        num = torch.sum(weighted * self.num_idx, dim=2)
        den = torch.sum(weighted * self.den_idx, dim=2)
        # Replace near-zero denominators with +eps (positive sign) or -eps
        # (zero or negative sign) to avoid dividing by ~0.
        sign = 2.0 * torch.clamp(torch.sign(den), min=0.0) - 1.0
        den = torch.where((den < eps) & (den > -eps), sign * eps, den)
        return num / den
[ ]:
from torch_geometric.nn import GINConv
from torch.nn import Linear
LAF Aggregation Module

[ ]:
class GINLAFConv(GINConv):
    """GIN convolution whose neighbourhood aggregation is a learned LAF.

    Args:
        nn: the MLP applied by ``GINConv`` after aggregation.
        units: number of LAF functions learned per feature.
        node_dim: width of the incoming node features. This is NOT
            ``MessagePassing``'s ``node_dim`` axis argument; it is consumed
            here and never forwarded to the base class.
        **kwargs: forwarded to ``GINConv``.
    """

    def __init__(self, nn, units=1, node_dim=32, **kwargs):
        super(GINLAFConv, self).__init__(nn, **kwargs)
        self.laf = LAFLayer(units=units, kernel_initializer='random_uniform')
        # Projects the `units` LAF outputs of every feature back to node_dim.
        self.mlp = torch.nn.Linear(node_dim * units, node_dim)
        self.dim = node_dim
        self.units = units

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Aggregate messages with LAF, then project with a linear layer.

        ``dim_size`` (number of target nodes) is requested from PyG's
        ``propagate`` so nodes without incoming edges still get an output row
        and the result matches the node count. ``ptr`` is accepted only for
        signature parity with ``MessagePassing.aggregate`` and is unused.
        """
        # LAF clamps its input into (0, 1), so squash messages first.
        x = torch.sigmoid(inputs)
        x = self.laf(x, index, dim_size=dim_size)
        x = x.view((-1, self.dim * self.units))
        return self.mlp(x)
PNA Aggregation

[ ]:
class GINPNAConv(GINConv):
    """GIN convolution with Principal Neighbourhood Aggregation (PNA).

    Four aggregators (sum, max, mean, variance) are each combined with three
    degree scalers (identity, amplification, attenuation). The resulting
    ``12 * node_dim`` features are projected back to ``node_dim``.

    Args:
        nn: the MLP applied by ``GINConv`` after aggregation.
        node_dim: width of the incoming node features (consumed here, not
            forwarded to ``MessagePassing``).
        delta: normaliser of the logarithmic degree scalers. PNA sets it to
            the average of ``log(d + 1)`` over the training graphs; the
            default keeps the value previously hard-coded here.
        **kwargs: forwarded to ``GINConv``.
    """

    def __init__(self, nn, node_dim=32, delta=2.5749, **kwargs):
        super(GINPNAConv, self).__init__(nn, **kwargs)
        self.mlp = torch.nn.Linear(node_dim * 12, node_dim)
        self.delta = delta

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        """Combine four scatter aggregations with three PNA degree scalers.

        ``dim_size`` (number of target nodes) is requested from PyG's
        ``propagate`` so every node gets an output row. ``ptr`` is accepted
        only for signature parity with ``MessagePassing.aggregate``.
        """
        sums = torch_scatter.scatter_add(inputs, index, dim=0, dim_size=dim_size)
        maxs = torch_scatter.scatter_max(inputs, index, dim=0, dim_size=dim_size)[0]
        means = torch_scatter.scatter_mean(inputs, index, dim=0, dim_size=dim_size)
        # Var[X] = E[X^2] - E[X]^2; relu clips small negative rounding errors.
        mean_sq = torch_scatter.scatter_mean(inputs ** 2, index, dim=0, dim_size=dim_size)
        var = torch.relu(mean_sq - means ** 2)
        aggrs = [sums, maxs, means, var]

        # In-degree per target node (one entry per output row), clamped to 1
        # so nodes without messages do not hit delta / log(0 + 1) = inf,
        # which times their all-zero aggregates would give NaN.
        deg = index.bincount(minlength=sums.size(0)).clamp(min=1).float().view(-1, 1)
        log_deg = torch.log(deg + 1.)
        # PNA scalers: amplification log(d+1)/delta, attenuation delta/log(d+1).
        amplification = [log_deg / self.delta * a for a in aggrs]
        attenuation = [self.delta / log_deg * a for a in aggrs]
        combinations = torch.cat(aggrs + amplification + attenuation, dim=1)
        return self.mlp(combinations)
Test the new classes
[ ]:
from torch_geometric.nn import MessagePassing, SAGEConv, GINConv, global_add_pool
import torch_scatter
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
import os.path as osp
[ ]:
# Load MUTAG, shuffle it once, and hold out the first 10% of graphs for testing.
# `torch_geometric.data.DataLoader` is deprecated; use the `loader` module.
from torch_geometric.loader import DataLoader
path = osp.join('./', 'data', 'TU')
dataset = TUDataset(path, name='MUTAG').shuffle()
test_dataset = dataset[:len(dataset) // 10]
train_dataset = dataset[len(dataset) // 10:]
test_loader = DataLoader(test_dataset, batch_size=128)
# Reshuffle the training graphs every epoch so mini-batches differ between epochs.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Extracting data/TU/MUTAG/MUTAG.zip
Processing...
Done!
/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
warnings.warn(out)
[ ]:
class LAFNet(torch.nn.Module):
    """Graph classifier: five GIN layers with LAF aggregation and sum pooling.

    Each stage applies ``relu(conv_i(x))`` followed by batch norm. Node
    embeddings are then summed per graph (``global_add_pool``) and classified
    by a two-layer head with dropout. Input and output sizes are read from
    the global ``dataset``. Returns per-graph log-probabilities.
    """

    @staticmethod
    def _mlp(in_dim, out_dim):
        # Inner MLP handed to each GIN layer: Linear -> ReLU -> Linear.
        return Sequential(Linear(in_dim, out_dim), ReLU(), Linear(out_dim, out_dim))

    def __init__(self):
        super(LAFNet, self).__init__()
        n_in = dataset.num_features
        hidden = 32
        n_units = 3
        # Layer 1 aggregates raw node features (node_dim=n_in); the remaining
        # layers aggregate hidden embeddings (node_dim=hidden).
        self.conv1 = GINLAFConv(self._mlp(n_in, hidden), units=n_units, node_dim=n_in)
        self.bn1 = torch.nn.BatchNorm1d(hidden)
        self.conv2 = GINLAFConv(self._mlp(hidden, hidden), units=n_units, node_dim=hidden)
        self.bn2 = torch.nn.BatchNorm1d(hidden)
        self.conv3 = GINLAFConv(self._mlp(hidden, hidden), units=n_units, node_dim=hidden)
        self.bn3 = torch.nn.BatchNorm1d(hidden)
        self.conv4 = GINLAFConv(self._mlp(hidden, hidden), units=n_units, node_dim=hidden)
        self.bn4 = torch.nn.BatchNorm1d(hidden)
        self.conv5 = GINLAFConv(self._mlp(hidden, hidden), units=n_units, node_dim=hidden)
        self.bn5 = torch.nn.BatchNorm1d(hidden)
        self.fc1 = Linear(hidden, hidden)
        self.fc2 = Linear(hidden, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3), (self.conv4, self.bn4),
                  (self.conv5, self.bn5))
        for conv, norm in stages:
            x = norm(F.relu(conv(x, edge_index)))
        out = global_add_pool(x, batch)
        out = F.dropout(F.relu(self.fc1(out)), p=0.5, training=self.training)
        return F.log_softmax(self.fc2(out), dim=-1)
[ ]:
class PNANet(torch.nn.Module):
    """Graph classifier: five GIN layers with PNA aggregation and sum pooling.

    Each stage applies ``relu(conv_i(x))`` followed by batch norm. Node
    embeddings are then summed per graph (``global_add_pool``) and classified
    by a two-layer head with dropout. Input and output sizes are read from
    the global ``dataset``. Returns per-graph log-probabilities.
    """

    @staticmethod
    def _mlp(in_dim, out_dim):
        # Inner MLP handed to each GIN layer: Linear -> ReLU -> Linear.
        return Sequential(Linear(in_dim, out_dim), ReLU(), Linear(out_dim, out_dim))

    def __init__(self):
        super(PNANet, self).__init__()
        n_in = dataset.num_features
        hidden = 32
        # Layer 1 aggregates raw node features (node_dim=n_in); the remaining
        # layers aggregate hidden embeddings (node_dim=hidden).
        self.conv1 = GINPNAConv(self._mlp(n_in, hidden), node_dim=n_in)
        self.bn1 = torch.nn.BatchNorm1d(hidden)
        self.conv2 = GINPNAConv(self._mlp(hidden, hidden), node_dim=hidden)
        self.bn2 = torch.nn.BatchNorm1d(hidden)
        self.conv3 = GINPNAConv(self._mlp(hidden, hidden), node_dim=hidden)
        self.bn3 = torch.nn.BatchNorm1d(hidden)
        self.conv4 = GINPNAConv(self._mlp(hidden, hidden), node_dim=hidden)
        self.bn4 = torch.nn.BatchNorm1d(hidden)
        self.conv5 = GINPNAConv(self._mlp(hidden, hidden), node_dim=hidden)
        self.bn5 = torch.nn.BatchNorm1d(hidden)
        self.fc1 = Linear(hidden, hidden)
        self.fc2 = Linear(hidden, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3), (self.conv4, self.bn4),
                  (self.conv5, self.bn5))
        for conv, norm in stages:
            x = norm(F.relu(conv(x, edge_index)))
        out = global_add_pool(x, batch)
        out = F.dropout(F.relu(self.fc1(out)), p=0.5, training=self.training)
        return F.log_softmax(self.fc2(out), dim=-1)
[ ]:
class GINNet(torch.nn.Module):
    """Baseline graph classifier: five stock GIN layers with sum pooling.

    Each stage applies ``relu(conv_i(x))`` followed by batch norm. Node
    embeddings are then summed per graph (``global_add_pool``) and classified
    by a two-layer head with dropout. Input and output sizes are read from
    the global ``dataset``. Returns per-graph log-probabilities.
    """

    @staticmethod
    def _mlp(in_dim, out_dim):
        # Inner MLP handed to each GIN layer: Linear -> ReLU -> Linear.
        return Sequential(Linear(in_dim, out_dim), ReLU(), Linear(out_dim, out_dim))

    def __init__(self):
        super(GINNet, self).__init__()
        n_in = dataset.num_features
        hidden = 32
        self.conv1 = GINConv(self._mlp(n_in, hidden))
        self.bn1 = torch.nn.BatchNorm1d(hidden)
        self.conv2 = GINConv(self._mlp(hidden, hidden))
        self.bn2 = torch.nn.BatchNorm1d(hidden)
        self.conv3 = GINConv(self._mlp(hidden, hidden))
        self.bn3 = torch.nn.BatchNorm1d(hidden)
        self.conv4 = GINConv(self._mlp(hidden, hidden))
        self.bn4 = torch.nn.BatchNorm1d(hidden)
        self.conv5 = GINConv(self._mlp(hidden, hidden))
        self.bn5 = torch.nn.BatchNorm1d(hidden)
        self.fc1 = Linear(hidden, hidden)
        self.fc2 = Linear(hidden, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3), (self.conv4, self.bn4),
                  (self.conv5, self.bn5))
        for conv, norm in stages:
            x = norm(F.relu(conv(x, edge_index)))
        out = global_add_pool(x, batch)
        out = F.dropout(F.relu(self.fc1(out)), p=0.5, training=self.training)
        return F.log_softmax(self.fc2(out), dim=-1)
[ ]:
# Choose the architecture to train ("LAF", "PNA" or "GIN") and build its optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = "LAF"
if net == "LAF":
    model = LAFNet().to(device)
elif net == "PNA":
    model = PNANet().to(device)
elif net == "GIN":
    # Assign the result: a bare `GINNet().to(device)` would leave `model`
    # pointing at whatever network was built before.
    model = GINNet().to(device)
else:
    raise ValueError("Unknown net {!r}; expected 'LAF', 'PNA' or 'GIN'".format(net))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
def train(epoch):
    """Run one training epoch and return the mean loss per training graph.

    The learning rate of every optimizer parameter group is halved once,
    when ``epoch == 51``. Reads the notebook globals ``model``,
    ``optimizer``, ``train_loader``, ``train_dataset`` and ``device``.
    """
    model.train()
    if epoch == 51:
        for group in optimizer.param_groups:
            group['lr'] *= 0.5

    total = 0
    for graphs in train_loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        log_probs = model(graphs.x, graphs.edge_index, graphs.batch)
        loss = F.nll_loss(log_probs, graphs.y)
        loss.backward()
        # nll_loss averages over the batch; re-weight by its graph count.
        total += loss.item() * graphs.num_graphs
        optimizer.step()
    return total / len(train_dataset)
@torch.no_grad()
def test(loader):
    """Return the classification accuracy of the global ``model`` on ``loader``.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is
    built during evaluation. Accuracy is the number of correct predictions
    divided by ``len(loader.dataset)``.
    """
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        output = model(data.x, data.edge_index, data.batch)
        pred = output.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)
# Train for 100 epochs, printing the training loss and train/test accuracy after each one.
for epoch in range(1, 101):
    train_loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
                                                       train_acc, test_acc))
Epoch: 001, Train Loss: 0.8650472, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 002, Train Loss: 0.7599028, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 003, Train Loss: 0.8972220, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 004, Train Loss: 0.6185434, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 005, Train Loss: 0.6005230, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 006, Train Loss: 0.5512175, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 007, Train Loss: 0.5332195, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 008, Train Loss: 0.5134736, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 009, Train Loss: 0.4718563, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 010, Train Loss: 0.4698687, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 011, Train Loss: 0.4464772, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 012, Train Loss: 0.4414581, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 013, Train Loss: 0.4507246, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 014, Train Loss: 0.4593955, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 015, Train Loss: 0.4188018, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 016, Train Loss: 0.3976869, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 017, Train Loss: 0.4080824, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 018, Train Loss: 0.4642429, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 019, Train Loss: 0.3612275, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 020, Train Loss: 0.3702769, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 021, Train Loss: 0.3751319, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 022, Train Loss: 0.3421200, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 023, Train Loss: 0.3866120, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 024, Train Loss: 0.3492658, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 025, Train Loss: 0.3558516, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 026, Train Loss: 0.3727173, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 027, Train Loss: 0.3154053, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 028, Train Loss: 0.3201577, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 029, Train Loss: 0.3272583, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 030, Train Loss: 0.3112883, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 031, Train Loss: 0.3407421, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 032, Train Loss: 0.2899052, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 033, Train Loss: 0.3580514, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 034, Train Loss: 0.2954516, Train Acc: 0.6764706, Test Acc: 0.6666667
Epoch: 035, Train Loss: 0.2975145, Train Acc: 0.6941176, Test Acc: 0.7777778
Epoch: 036, Train Loss: 0.3173143, Train Acc: 0.7235294, Test Acc: 0.8333333
Epoch: 037, Train Loss: 0.2602276, Train Acc: 0.7000000, Test Acc: 0.7777778
Epoch: 038, Train Loss: 0.2713226, Train Acc: 0.7117647, Test Acc: 0.8333333
Epoch: 039, Train Loss: 0.2706065, Train Acc: 0.6941176, Test Acc: 0.7777778
Epoch: 040, Train Loss: 0.2786444, Train Acc: 0.6941176, Test Acc: 0.6666667
Epoch: 041, Train Loss: 0.2833781, Train Acc: 0.7117647, Test Acc: 0.7777778
Epoch: 042, Train Loss: 0.2924134, Train Acc: 0.7117647, Test Acc: 0.7777778
Epoch: 043, Train Loss: 0.2812036, Train Acc: 0.8235294, Test Acc: 0.9444444
Epoch: 044, Train Loss: 0.2887369, Train Acc: 0.7941176, Test Acc: 0.8888889
Epoch: 045, Train Loss: 0.2367283, Train Acc: 0.7529412, Test Acc: 0.7777778
Epoch: 046, Train Loss: 0.2811927, Train Acc: 0.7941176, Test Acc: 0.8888889
Epoch: 047, Train Loss: 0.2571158, Train Acc: 0.7588235, Test Acc: 0.7777778
Epoch: 048, Train Loss: 0.2812370, Train Acc: 0.5588235, Test Acc: 0.6111111
Epoch: 049, Train Loss: 0.2664493, Train Acc: 0.8294118, Test Acc: 0.7222222
Epoch: 050, Train Loss: 0.2698024, Train Acc: 0.8117647, Test Acc: 0.8333333
Epoch: 051, Train Loss: 0.2950443, Train Acc: 0.7647059, Test Acc: 0.6666667
Epoch: 052, Train Loss: 0.2750369, Train Acc: 0.7352941, Test Acc: 0.6666667
Epoch: 053, Train Loss: 0.2846459, Train Acc: 0.7882353, Test Acc: 0.7777778
Epoch: 054, Train Loss: 0.2428172, Train Acc: 0.7588235, Test Acc: 0.8888889
Epoch: 055, Train Loss: 0.2569554, Train Acc: 0.7647059, Test Acc: 0.8888889
Epoch: 056, Train Loss: 0.2893244, Train Acc: 0.7647059, Test Acc: 0.8888889
Epoch: 057, Train Loss: 0.2695741, Train Acc: 0.8058824, Test Acc: 0.8333333
Epoch: 058, Train Loss: 0.2683432, Train Acc: 0.8235294, Test Acc: 0.8333333
Epoch: 059, Train Loss: 0.2483253, Train Acc: 0.8294118, Test Acc: 0.8333333
Epoch: 060, Train Loss: 0.2359067, Train Acc: 0.8117647, Test Acc: 0.8333333
Epoch: 061, Train Loss: 0.2581255, Train Acc: 0.8235294, Test Acc: 0.8333333
Epoch: 062, Train Loss: 0.2320385, Train Acc: 0.8882353, Test Acc: 0.8333333
Epoch: 063, Train Loss: 0.2304887, Train Acc: 0.8647059, Test Acc: 0.8333333
Epoch: 064, Train Loss: 0.2351827, Train Acc: 0.8176471, Test Acc: 0.8333333
Epoch: 065, Train Loss: 0.2371133, Train Acc: 0.7647059, Test Acc: 0.7777778
Epoch: 066, Train Loss: 0.2476480, Train Acc: 0.7705882, Test Acc: 0.7777778
Epoch: 067, Train Loss: 0.2557588, Train Acc: 0.9176471, Test Acc: 0.7777778
Epoch: 068, Train Loss: 0.2158999, Train Acc: 0.8352941, Test Acc: 0.7222222
Epoch: 069, Train Loss: 0.2353542, Train Acc: 0.7882353, Test Acc: 0.6111111
Epoch: 070, Train Loss: 0.2403484, Train Acc: 0.7705882, Test Acc: 0.5555556
Epoch: 071, Train Loss: 0.2292482, Train Acc: 0.7764706, Test Acc: 0.7222222
Epoch: 072, Train Loss: 0.2588242, Train Acc: 0.8000000, Test Acc: 0.7222222
Epoch: 073, Train Loss: 0.2330211, Train Acc: 0.8882353, Test Acc: 0.6666667
Epoch: 074, Train Loss: 0.2573530, Train Acc: 0.8235294, Test Acc: 0.7777778
Epoch: 075, Train Loss: 0.2250361, Train Acc: 0.8941176, Test Acc: 0.7777778
Epoch: 076, Train Loss: 0.2089147, Train Acc: 0.9058824, Test Acc: 0.7222222
Epoch: 077, Train Loss: 0.2145577, Train Acc: 0.9117647, Test Acc: 0.7777778
Epoch: 078, Train Loss: 0.2313216, Train Acc: 0.8705882, Test Acc: 0.7222222
Epoch: 079, Train Loss: 0.2348573, Train Acc: 0.8470588, Test Acc: 0.6666667
Epoch: 080, Train Loss: 0.2337190, Train Acc: 0.8176471, Test Acc: 0.6666667
Epoch: 081, Train Loss: 0.2247560, Train Acc: 0.7352941, Test Acc: 0.6111111
Epoch: 082, Train Loss: 0.2352007, Train Acc: 0.6352941, Test Acc: 0.5000000
Epoch: 083, Train Loss: 0.2404233, Train Acc: 0.7411765, Test Acc: 0.6111111
Epoch: 084, Train Loss: 0.2203369, Train Acc: 0.8000000, Test Acc: 0.6666667
Epoch: 085, Train Loss: 0.2096777, Train Acc: 0.8647059, Test Acc: 0.7777778
Epoch: 086, Train Loss: 0.2133037, Train Acc: 0.8411765, Test Acc: 0.8333333
Epoch: 087, Train Loss: 0.1921520, Train Acc: 0.8411765, Test Acc: 0.8333333
Epoch: 088, Train Loss: 0.2259413, Train Acc: 0.9117647, Test Acc: 0.7777778
Epoch: 089, Train Loss: 0.2021636, Train Acc: 0.9176471, Test Acc: 0.7777778
Epoch: 090, Train Loss: 0.1980333, Train Acc: 0.9294118, Test Acc: 0.7777778
Epoch: 091, Train Loss: 0.2226479, Train Acc: 0.8882353, Test Acc: 0.7222222
Epoch: 092, Train Loss: 0.2319808, Train Acc: 0.8294118, Test Acc: 0.6666667
Epoch: 093, Train Loss: 0.2365441, Train Acc: 0.8117647, Test Acc: 0.6666667
Epoch: 094, Train Loss: 0.2389078, Train Acc: 0.8764706, Test Acc: 0.6666667
Epoch: 095, Train Loss: 0.1904467, Train Acc: 0.9000000, Test Acc: 0.7777778
Epoch: 096, Train Loss: 0.2198207, Train Acc: 0.9235294, Test Acc: 0.7777778
Epoch: 097, Train Loss: 0.2253925, Train Acc: 0.9235294, Test Acc: 0.7777778
Epoch: 098, Train Loss: 0.1958107, Train Acc: 0.8352941, Test Acc: 0.7777778
Epoch: 099, Train Loss: 0.2277206, Train Acc: 0.8294118, Test Acc: 0.7777778
Epoch: 100, Train Loss: 0.2341601, Train Acc: 0.7058824, Test Acc: 0.7222222
[ ]:
# Switch the architecture under study to PNA and build a fresh optimizer for it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = "PNA"
if net == "LAF":
    model = LAFNet().to(device)
elif net == "PNA":
    model = PNANet().to(device)
elif net == "GIN":
    # Assign the result: a bare `GINNet().to(device)` would leave `model`
    # pointing at whatever network was built before.
    model = GINNet().to(device)
else:
    raise ValueError("Unknown net {!r}; expected 'LAF', 'PNA' or 'GIN'".format(net))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Train for 100 epochs, printing the training loss and train/test accuracy after each one.
for epoch in range(1, 101):
    train_loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
                                                       train_acc, test_acc))
Epoch: 001, Train Loss: 1.3497391, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 002, Train Loss: 0.8684199, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 003, Train Loss: 0.7279473, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 004, Train Loss: 0.7402998, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 005, Train Loss: 0.7657306, Train Acc: 0.3294118, Test Acc: 0.3888889
Epoch: 006, Train Loss: 0.7453549, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 007, Train Loss: 0.5307669, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 008, Train Loss: 0.4403997, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 009, Train Loss: 0.4544508, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 010, Train Loss: 0.4488042, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 011, Train Loss: 0.4297011, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 012, Train Loss: 0.3781979, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 013, Train Loss: 0.4004532, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 014, Train Loss: 0.3619624, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 015, Train Loss: 0.3303704, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 016, Train Loss: 0.3489703, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 017, Train Loss: 0.2879844, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 018, Train Loss: 0.2992957, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 019, Train Loss: 0.2941008, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 020, Train Loss: 0.2742822, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 021, Train Loss: 0.2828649, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 022, Train Loss: 0.2448842, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 023, Train Loss: 0.2426275, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 024, Train Loss: 0.2311836, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 025, Train Loss: 0.1974013, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 026, Train Loss: 0.1801775, Train Acc: 0.6647059, Test Acc: 0.6111111
Epoch: 027, Train Loss: 0.2027490, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 028, Train Loss: 0.1469845, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 029, Train Loss: 0.1597498, Train Acc: 0.6705882, Test Acc: 0.6111111
Epoch: 030, Train Loss: 0.1537053, Train Acc: 0.6823529, Test Acc: 0.6666667
Epoch: 031, Train Loss: 0.1487810, Train Acc: 0.7117647, Test Acc: 0.6666667
Epoch: 032, Train Loss: 0.1435155, Train Acc: 0.7176471, Test Acc: 0.7222222
Epoch: 033, Train Loss: 0.1291888, Train Acc: 0.7176471, Test Acc: 0.7222222
Epoch: 034, Train Loss: 0.1226045, Train Acc: 0.7470588, Test Acc: 0.7222222
Epoch: 035, Train Loss: 0.1137753, Train Acc: 0.7588235, Test Acc: 0.6666667
Epoch: 036, Train Loss: 0.1119650, Train Acc: 0.8058824, Test Acc: 0.7222222
Epoch: 037, Train Loss: 0.1085063, Train Acc: 0.8235294, Test Acc: 0.7222222
Epoch: 038, Train Loss: 0.1228803, Train Acc: 0.8352941, Test Acc: 0.7222222
Epoch: 039, Train Loss: 0.0784458, Train Acc: 0.8411765, Test Acc: 0.7777778
Epoch: 040, Train Loss: 0.0854303, Train Acc: 0.8764706, Test Acc: 0.8333333
Epoch: 041, Train Loss: 0.1073735, Train Acc: 0.9176471, Test Acc: 0.8333333
Epoch: 042, Train Loss: 0.0834263, Train Acc: 0.9176471, Test Acc: 0.8333333
Epoch: 043, Train Loss: 0.0607265, Train Acc: 0.9294118, Test Acc: 0.8333333
Epoch: 044, Train Loss: 0.0718378, Train Acc: 0.9352941, Test Acc: 0.8333333
Epoch: 045, Train Loss: 0.0689468, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 046, Train Loss: 0.0382091, Train Acc: 0.9529412, Test Acc: 0.7777778
Epoch: 047, Train Loss: 0.0879735, Train Acc: 0.9352941, Test Acc: 0.7777778
Epoch: 048, Train Loss: 0.0623109, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 049, Train Loss: 0.0908670, Train Acc: 0.9058824, Test Acc: 0.7777778
Epoch: 050, Train Loss: 0.0681597, Train Acc: 0.9000000, Test Acc: 0.8333333
Epoch: 051, Train Loss: 0.0693567, Train Acc: 0.9176471, Test Acc: 0.8333333
Epoch: 052, Train Loss: 0.0478872, Train Acc: 0.9352941, Test Acc: 0.8333333
Epoch: 053, Train Loss: 0.0506401, Train Acc: 0.9352941, Test Acc: 0.8333333
Epoch: 054, Train Loss: 0.0308294, Train Acc: 0.9294118, Test Acc: 0.8333333
Epoch: 055, Train Loss: 0.0326454, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 056, Train Loss: 0.0250327, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 057, Train Loss: 0.0349234, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 058, Train Loss: 0.0380354, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 059, Train Loss: 0.0238508, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 060, Train Loss: 0.0260360, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 061, Train Loss: 0.0156592, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 062, Train Loss: 0.0397532, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 063, Train Loss: 0.0147181, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 064, Train Loss: 0.0263763, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 065, Train Loss: 0.0219646, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 066, Train Loss: 0.0174770, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 067, Train Loss: 0.0233532, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 068, Train Loss: 0.0329869, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 069, Train Loss: 0.0267206, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 070, Train Loss: 0.0195115, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 071, Train Loss: 0.0263306, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 072, Train Loss: 0.0161402, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 073, Train Loss: 0.0138596, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 074, Train Loss: 0.0176732, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 075, Train Loss: 0.0140430, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 076, Train Loss: 0.0223834, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 077, Train Loss: 0.0151263, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 078, Train Loss: 0.0113194, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 079, Train Loss: 0.0178343, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 080, Train Loss: 0.0132977, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 081, Train Loss: 0.0099823, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 082, Train Loss: 0.0103535, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 083, Train Loss: 0.0049559, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 084, Train Loss: 0.0115411, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 085, Train Loss: 0.0132454, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 086, Train Loss: 0.0139688, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 087, Train Loss: 0.0082945, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 088, Train Loss: 0.0144088, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 089, Train Loss: 0.0116169, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 090, Train Loss: 0.0115055, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 091, Train Loss: 0.0044924, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 092, Train Loss: 0.0073951, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 093, Train Loss: 0.0098597, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 094, Train Loss: 0.0071243, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 095, Train Loss: 0.0084314, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 096, Train Loss: 0.0116200, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 097, Train Loss: 0.0109158, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 098, Train Loss: 0.0088956, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 099, Train Loss: 0.0098493, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 100, Train Loss: 0.0082795, Train Acc: 0.9529412, Test Acc: 0.8333333
[ ]:
# Switch the architecture under study to the plain GIN baseline and build a fresh optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = "GIN"
if net == "LAF":
    model = LAFNet().to(device)
elif net == "PNA":
    model = PNANet().to(device)
elif net == "GIN":
    # Assign the result: a bare `GINNet().to(device)` would leave `model`
    # pointing at the previously trained network, so this "GIN" run would
    # silently keep training it.
    model = GINNet().to(device)
else:
    raise ValueError("Unknown net {!r}; expected 'LAF', 'PNA' or 'GIN'".format(net))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Train for 100 epochs, printing the training loss and train/test accuracy after each one.
for epoch in range(1, 101):
    train_loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print('Epoch: {:03d}, Train Loss: {:.7f}, '
          'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
                                                       train_acc, test_acc))
Epoch: 001, Train Loss: 0.1006957, Train Acc: 0.8705882, Test Acc: 0.8888889
Epoch: 002, Train Loss: 0.5604971, Train Acc: 0.9235294, Test Acc: 0.7222222
Epoch: 003, Train Loss: 0.4205183, Train Acc: 0.8588235, Test Acc: 0.6111111
Epoch: 004, Train Loss: 0.1701126, Train Acc: 0.8176471, Test Acc: 0.6111111
Epoch: 005, Train Loss: 0.1697284, Train Acc: 0.8411765, Test Acc: 0.7777778
Epoch: 006, Train Loss: 0.1547725, Train Acc: 0.8176471, Test Acc: 0.7222222
Epoch: 007, Train Loss: 0.1122712, Train Acc: 0.7529412, Test Acc: 0.7777778
Epoch: 008, Train Loss: 0.1182288, Train Acc: 0.7470588, Test Acc: 0.7222222
Epoch: 009, Train Loss: 0.1329069, Train Acc: 0.9117647, Test Acc: 0.8333333
Epoch: 010, Train Loss: 0.1019645, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 011, Train Loss: 0.0771604, Train Acc: 0.9529412, Test Acc: 0.7777778
Epoch: 012, Train Loss: 0.0847688, Train Acc: 0.9176471, Test Acc: 0.7777778
Epoch: 013, Train Loss: 0.0684039, Train Acc: 0.9235294, Test Acc: 0.7777778
Epoch: 014, Train Loss: 0.0651711, Train Acc: 0.9470588, Test Acc: 0.7777778
Epoch: 015, Train Loss: 0.0518811, Train Acc: 0.9588235, Test Acc: 0.7777778
Epoch: 016, Train Loss: 0.0677568, Train Acc: 0.9705882, Test Acc: 0.7777778
Epoch: 017, Train Loss: 0.0393111, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 018, Train Loss: 0.0367973, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 019, Train Loss: 0.0366539, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 020, Train Loss: 0.0568547, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 021, Train Loss: 0.0447065, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 022, Train Loss: 0.0352459, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 023, Train Loss: 0.0249647, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 024, Train Loss: 0.0145648, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 025, Train Loss: 0.0205373, Train Acc: 0.9823529, Test Acc: 0.7777778
Epoch: 026, Train Loss: 0.0161799, Train Acc: 0.9647059, Test Acc: 0.7777778
Epoch: 027, Train Loss: 0.0125704, Train Acc: 0.9588235, Test Acc: 0.7777778
Epoch: 028, Train Loss: 0.0112206, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 029, Train Loss: 0.0095180, Train Acc: 0.9411765, Test Acc: 0.8333333
Epoch: 030, Train Loss: 0.0139799, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 031, Train Loss: 0.0133235, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 032, Train Loss: 0.0116212, Train Acc: 0.9588235, Test Acc: 0.8333333
Epoch: 033, Train Loss: 0.0074385, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 034, Train Loss: 0.0063465, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 035, Train Loss: 0.0093689, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 036, Train Loss: 0.0118155, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 037, Train Loss: 0.0166583, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 038, Train Loss: 0.0108432, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 039, Train Loss: 0.0092749, Train Acc: 0.9529412, Test Acc: 0.7777778
Epoch: 040, Train Loss: 0.0081560, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 041, Train Loss: 0.0145553, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 042, Train Loss: 0.0051442, Train Acc: 0.9764706, Test Acc: 0.7222222
Epoch: 043, Train Loss: 0.0128016, Train Acc: 0.9823529, Test Acc: 0.7222222
Epoch: 044, Train Loss: 0.0083365, Train Acc: 0.9823529, Test Acc: 0.6666667
Epoch: 045, Train Loss: 0.0449262, Train Acc: 0.9470588, Test Acc: 0.7222222
Epoch: 046, Train Loss: 0.1241174, Train Acc: 0.9352941, Test Acc: 0.7222222
Epoch: 047, Train Loss: 0.0577372, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 048, Train Loss: 0.0158565, Train Acc: 0.9588235, Test Acc: 0.7222222
Epoch: 049, Train Loss: 0.0264535, Train Acc: 0.9647059, Test Acc: 0.7222222
Epoch: 050, Train Loss: 0.0493911, Train Acc: 0.9470588, Test Acc: 0.7777778
Epoch: 051, Train Loss: 0.0307947, Train Acc: 0.9647059, Test Acc: 0.7777778
Epoch: 052, Train Loss: 0.0502689, Train Acc: 0.9705882, Test Acc: 0.7777778
Epoch: 053, Train Loss: 0.0220471, Train Acc: 0.9411765, Test Acc: 0.7777778
Epoch: 054, Train Loss: 0.0271277, Train Acc: 0.9294118, Test Acc: 0.8333333
Epoch: 055, Train Loss: 0.0193326, Train Acc: 0.9529412, Test Acc: 0.8888889
Epoch: 056, Train Loss: 0.0085988, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 057, Train Loss: 0.0200223, Train Acc: 0.9470588, Test Acc: 0.8333333
Epoch: 058, Train Loss: 0.0065113, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 059, Train Loss: 0.0118877, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 060, Train Loss: 0.0138910, Train Acc: 0.9470588, Test Acc: 0.8888889
Epoch: 061, Train Loss: 0.0099632, Train Acc: 0.9411765, Test Acc: 0.8888889
Epoch: 062, Train Loss: 0.0104697, Train Acc: 0.9411765, Test Acc: 0.8888889
Epoch: 063, Train Loss: 0.0117506, Train Acc: 0.9411765, Test Acc: 0.8888889
Epoch: 064, Train Loss: 0.0129451, Train Acc: 0.9470588, Test Acc: 0.8888889
Epoch: 065, Train Loss: 0.0049019, Train Acc: 0.9529412, Test Acc: 0.8888889
Epoch: 066, Train Loss: 0.0059774, Train Acc: 0.9529412, Test Acc: 0.8888889
Epoch: 067, Train Loss: 0.0029972, Train Acc: 0.9529412, Test Acc: 0.8333333
Epoch: 068, Train Loss: 0.0070204, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 069, Train Loss: 0.0058905, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 070, Train Loss: 0.0122656, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 071, Train Loss: 0.0080602, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 072, Train Loss: 0.0048456, Train Acc: 0.9588235, Test Acc: 0.8888889
Epoch: 073, Train Loss: 0.0024820, Train Acc: 0.9647059, Test Acc: 0.8888889
Epoch: 074, Train Loss: 0.0066221, Train Acc: 0.9647059, Test Acc: 0.8888889
Epoch: 075, Train Loss: 0.0054791, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 076, Train Loss: 0.0041069, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 077, Train Loss: 0.0039224, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 078, Train Loss: 0.0038528, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 079, Train Loss: 0.0026217, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 080, Train Loss: 0.0029335, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 081, Train Loss: 0.0046612, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 082, Train Loss: 0.0039182, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 083, Train Loss: 0.0048634, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 084, Train Loss: 0.0021048, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 085, Train Loss: 0.0055200, Train Acc: 0.9823529, Test Acc: 0.8333333
Epoch: 086, Train Loss: 0.0030238, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 087, Train Loss: 0.0048484, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 088, Train Loss: 0.0029367, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 089, Train Loss: 0.0014313, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 090, Train Loss: 0.0061374, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 091, Train Loss: 0.0055641, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 092, Train Loss: 0.0049570, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 093, Train Loss: 0.0008855, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 094, Train Loss: 0.0033897, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 095, Train Loss: 0.0020113, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 096, Train Loss: 0.0019971, Train Acc: 0.9764706, Test Acc: 0.8333333
Epoch: 097, Train Loss: 0.0073497, Train Acc: 0.9705882, Test Acc: 0.8333333
Epoch: 098, Train Loss: 0.0037542, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 099, Train Loss: 0.0011019, Train Acc: 0.9647059, Test Acc: 0.8333333
Epoch: 100, Train Loss: 0.0037228, Train Acc: 0.9647059, Test Acc: 0.8888889