-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
mnist_nn_conv.py
113 lines (88 loc) · 3.31 KB
/
mnist_nn_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import copy
import os.path as osp
import sys

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential

import torch_geometric.transforms as T
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    NNConv,
    global_mean_pool,
    graclus,
    max_pool,
    max_pool_x,
)
from torch_geometric.typing import WITH_TORCH_CLUSTER
from torch_geometric.utils import normalized_cut
if not WITH_TORCH_CLUSTER:
    # quit() is injected by the `site` module for interactive sessions and
    # is missing under `python -S`; sys.exit() is the supported way for a
    # program to abort. A string argument is printed to stderr and the
    # process exits with status 1, exactly as quit(msg) did.
    sys.exit("This example requires 'torch-cluster'")

# Dataset root: <this file's directory>/../data/MNIST
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')
# Edge attributes are relative Cartesian offsets between node positions
# (cat=False: do not concatenate onto existing edge attributes). The same
# transform object is reused by Net when it re-pools the graph.
transform = T.Cartesian(cat=False)
train_dataset = MNISTSuperpixels(path, True, transform=transform)
test_dataset = MNISTSuperpixels(path, False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Short alias that Net reads num_features / num_classes from.
d = train_dataset
def normalized_cut_2d(edge_index, pos):
    """Compute graclus edge weights from node positions.

    Each edge is first weighted by the Euclidean (L2) distance between the
    positions of its two endpoints. Those distances are then passed to
    ``normalized_cut``.

    Args:
        edge_index: Edge index tensor, unpacked as ``(row, col)``.
        pos: Node position tensor, indexed by node id along dim 0.

    Returns:
        The per-edge weights returned by ``normalized_cut``.
    """
    row, col = edge_index
    # torch.norm is deprecated; torch.linalg.vector_norm is its documented
    # replacement and yields the same L2 norm along dim 1.
    edge_attr = torch.linalg.vector_norm(pos[row] - pos[col], ord=2, dim=1)
    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))
class Net(torch.nn.Module):
    """Graph classifier: two edge-conditioned convolutions with graclus pooling.

    Layer widths are taken from the module-level dataset alias ``d``:
    ``d.num_features`` input channels and ``d.num_classes`` outputs.
    """

    def __init__(self):
        super().__init__()
        # NNConv builds a per-edge weight matrix from the edge attributes.
        # Each MLP maps a 2-d edge attribute to a flattened
        # (in_channels * out_channels) weight.
        nn1 = Sequential(
            Linear(2, 25),
            ReLU(),
            Linear(25, d.num_features * 32),
        )
        self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean')

        nn2 = Sequential(
            Linear(2, 25),
            ReLU(),
            Linear(25, 32 * 64),
        )
        self.conv2 = NNConv(32, 64, nn2, aggr='mean')

        self.fc1 = torch.nn.Linear(64, 128)
        self.fc2 = torch.nn.Linear(128, d.num_classes)

    def forward(self, data):
        """Classify every graph in a mini-batch.

        Args:
            data: A batch from the loader exposing ``x``, ``edge_index``,
                ``edge_attr``, ``pos`` and ``batch``. It is not modified.

        Returns:
            ``log_softmax`` scores over ``d.num_classes`` classes, one row
            per graph in the batch.
        """
        # Work on a shallow copy. Otherwise the attribute assignments below
        # overwrite ``x`` / ``edge_attr`` on the caller's batch, and running
        # the same batch through the model again would break: it would see
        # 32-channel features and no edge attributes.
        data = copy.copy(data)

        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # Discard the old edge attributes; ``transform`` recomputes them on
        # the coarsened graph that max_pool builds.
        data.edge_attr = None
        data = max_pool(cluster, data, transform=transform)

        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))
        weight = normalized_cut_2d(data.edge_index, data.pos)
        cluster = graclus(data.edge_index, weight, data.x.size(0))
        # The second pooling only reduces node features. Graph readout
        # follows immediately, so no coarsened edge structure is needed.
        x, batch = max_pool_x(cluster, data.x, data.batch)

        x = global_mean_pool(x, batch)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = Net()
model = model.to(device)
# Adam starts at lr=0.01; train() lowers the rate at later epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    At epoch 16 the learning rate of every optimizer parameter group is set
    to 0.001, and at epoch 26 to 0.0001. Any other epoch leaves the rate
    unchanged.

    Args:
        epoch: The 1-based epoch number, used only for the LR milestones.
    """
    model.train()

    milestone_lr = {16: 0.001, 26: 0.0001}.get(epoch)
    if milestone_lr is not None:
        for group in optimizer.param_groups:
            group['lr'] = milestone_lr

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        log_probs = model(batch)
        loss = F.nll_loss(log_probs, batch.y)
        loss.backward()
        optimizer.step()
@torch.no_grad()
def test():
    """Return the classification accuracy of ``model`` on ``test_dataset``.

    Gradient tracking is disabled: evaluation never calls backward(), so
    recording an autograd graph for every batch would only waste memory.

    Returns:
        The fraction of test graphs whose arg-max prediction equals ``y``.
    """
    model.eval()
    correct = 0
    for data in test_loader:
        data = data.to(device)
        pred = model(data).argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
    return correct / len(test_dataset)
# Epochs are numbered from 1 so they line up with the learning-rate
# milestones (16 and 26) checked inside train(); 30 epochs in total.
for epoch in range(1, 30 + 1):
    train(epoch=epoch)
    test_acc = test()
    print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}')