Commit c39e805

refactor gat experiment

1 parent 79505c4 commit c39e805

2 files changed: +58 -235 lines changed


labml_nn/graphs/gat/experiment.py

Lines changed: 33 additions & 23 deletions
@@ -17,7 +17,7 @@
 from torch import nn
 
 from labml import lab, monit, tracker, experiment
-from labml.configs import BaseConfigs
+from labml.configs import BaseConfigs, option, calculate
 from labml.utils import download
 from labml_helpers.device import DeviceConfigs
 from labml_helpers.module import Module
@@ -194,26 +194,6 @@ class Configs(BaseConfigs):
     # Optimizer
     optimizer: torch.optim.Adam
 
-    def initialize(self):
-        """
-        Initialize
-        """
-        # Create the dataset
-        self.dataset = CoraDataset(self.include_edges)
-        # Get the number of classes
-        self.n_classes = len(self.dataset.classes)
-        # Number of features in the input
-        self.in_features = self.dataset.features.shape[1]
-        # Create the model
-        self.model = GAT(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
-        # Move the model to the device
-        self.model.to(self.device)
-        # Configurable optimizer, so that we can set the configurations
-        # such as learning rate by passing the dictionary later.
-        optimizer_conf = OptimizerConfigs()
-        optimizer_conf.parameters = self.model.parameters()
-        self.optimizer = optimizer_conf
-
     def run(self):
         """
         ### Training loop
@@ -276,6 +256,38 @@ def run(self):
             tracker.save()
 
 
+@option(Configs.dataset)
+def cora_dataset(c: Configs):
+    """
+    Create Cora dataset
+    """
+    return CoraDataset(c.include_edges)
+
+
+# Get the number of classes
+calculate(Configs.n_classes, lambda c: len(c.dataset.classes))
+# Number of features in the input
+calculate(Configs.in_features, lambda c: c.dataset.features.shape[1])
+
+
+@option(Configs.model)
+def gat_model(c: Configs):
+    """
+    Create GAT model
+    """
+    return GAT(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout).to(c.device)
+
+
+@option(Configs.optimizer)
+def _optimizer(c: Configs):
+    """
+    Create configurable optimizer
+    """
+    opt_conf = OptimizerConfigs()
+    opt_conf.parameters = c.model.parameters()
+    return opt_conf
+
+
 def main():
     # Create configurations
     conf = Configs()
@@ -288,8 +300,6 @@ def main():
         'optimizer.learning_rate': 5e-3,
         'optimizer.weight_decay': 5e-4,
     })
-    # Initialize
-    conf.initialize()
 
     # Start and watch the experiment
     with experiment.start():
labml_nn/graphs/gatv2/experiment.py

Lines changed: 25 additions & 212 deletions
@@ -10,102 +10,14 @@
 [![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
 """
 
-from typing import Dict
-
-import numpy as np
 import torch
 from torch import nn
 
-from labml import lab, monit, tracker, experiment
-from labml.configs import BaseConfigs
-from labml.utils import download
-from labml_helpers.device import DeviceConfigs
+from labml import experiment
+from labml.configs import option
 from labml_helpers.module import Module
+from labml_nn.graphs.gat.experiment import Configs as GATConfigs
 from labml_nn.graphs.gatv2 import GraphAttentionV2Layer
-from labml_nn.optimizers.configs import OptimizerConfigs
-
-
-class CoraDataset:
-    """
-    ## [Cora Dataset](https://linqs.soe.ucsc.edu/data)
-
-    Cora dataset is a dataset of research papers.
-    For each paper we are given a binary feature vector that indicates the presence of words.
-    Each paper is classified into one of 7 classes.
-    The dataset also has the citation network.
-
-    The papers are the nodes of the graph and the edges are the citations.
-
-    The task is to classify the nodes into the 7 classes with the feature vectors and
-    citation network as input.
-    """
-    # Labels for each node
-    labels: torch.Tensor
-    # Set of class names and a unique integer index for each
-    classes: Dict[str, int]
-    # Feature vectors for all nodes
-    features: torch.Tensor
-    # Adjacency matrix with the edge information.
-    # `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.
-    adj_mat: torch.Tensor
-
-    @staticmethod
-    def _download():
-        """
-        Download the dataset
-        """
-        if not (lab.get_data_path() / 'cora').exists():
-            download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
-                                   lab.get_data_path() / 'cora.tgz')
-            download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
-
-    def __init__(self, include_edges: bool = True):
-        """
-        Load the dataset
-        """
-
-        # Whether to include edges.
-        # This is to test how much accuracy is lost if we ignore the citation network.
-        self.include_edges = include_edges
-
-        # Download dataset
-        self._download()
-
-        # Read the paper ids, feature vectors, and labels
-        with monit.section('Read content file'):
-            content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
-        # Load the citations; it's a list of pairs of integers.
-        with monit.section('Read citations file'):
-            citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
-
-        # Get the feature vectors
-        features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
-        # Normalize the feature vectors
-        self.features = features / features.sum(dim=1, keepdim=True)
-
-        # Get the class names and assign a unique integer to each of them
-        self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
-        # Get the labels as those integers
-        self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
-
-        # Get the paper ids
-        paper_ids = np.array(content[:, 0], dtype=np.int32)
-        # Map of paper id to index
-        ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
-
-        # Empty adjacency matrix - an identity matrix
-        self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
-
-        # Mark the citations in the adjacency matrix
-        if self.include_edges:
-            for e in citations:
-                # The pair of paper indexes
-                e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
-                # We build a symmetrical graph, where if paper $i$ referenced
-                # paper $j$ we place an edge from $i$ to $j$ as well as an edge
-                # from $j$ to $i$.
-                self.adj_mat[e1][e2] = True
-                self.adj_mat[e2][e1] = True
 
 
 class GATv2(Module):
@@ -115,7 +27,8 @@ class GATv2(Module):
     This graph attention network has two [graph attention layers](index.html).
     """
 
-    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float, share_weights: bool = True):
+    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
+                 share_weights: bool = True):
         """
         * `in_features` is the number of features per node
         * `n_hidden` is the number of features in the first graph attention layer
@@ -127,11 +40,13 @@ def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int
         super().__init__()
 
         # First graph attention layer where we concatenate the heads
-        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout, share_weights=share_weights)
+        self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout,
+                                            share_weights=share_weights)
         # Activation function after first graph attention layer
         self.activation = nn.ELU()
         # Final graph attention layer where we average the heads
-        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout, share_weights=share_weights)
+        self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout,
+                                            share_weights=share_weights)
         # Dropout
         self.dropout = nn.Dropout(dropout)
 
@@ -153,128 +68,26 @@ def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor):
         return self.output(x, adj_mat)
 
 
-def accuracy(output: torch.Tensor, labels: torch.Tensor):
-    """
-    A simple function to calculate the accuracy
-    """
-    return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
-
-
-class Configs(BaseConfigs):
+class Configs(GATConfigs):
     """
     ## Configurations
-    """
-
-    # Model
-    model: GATv2
-    # Number of nodes to train on
-    training_samples: int = 500
-    # Number of features per node in the input
-    in_features: int
-    # Number of features in the first graph attention layer
-    n_hidden: int = 64
-    # Number of heads
-    n_heads: int = 8
-    # Number of classes for classification
-    n_classes: int
-    # Dropout probability
-    dropout: float = 0.7
-    # Whether to include the citation network
-    include_edges: bool = True
-    # Dataset
-    dataset: CoraDataset
-    # Number of training iterations
-    epochs: int = 1_000
-    # Loss function
-    loss_func = nn.CrossEntropyLoss()
-    # Device to train on
-    #
-    # This creates configs for device, so that
-    # we can change the device by passing a config value
-    device: torch.device = DeviceConfigs()
-    # Optimizer
-    optimizer: torch.optim.Adam
-
-    def initialize(self):
-        """
-        Initialize
-        """
-        # Create the dataset
-        self.dataset = CoraDataset(self.include_edges)
-        # Get the number of classes
-        self.n_classes = len(self.dataset.classes)
-        # Number of features in the input
-        self.in_features = self.dataset.features.shape[1]
-        # Create the model
-        self.model = GATv2(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
-        # Move the model to the device
-        self.model.to(self.device)
-        # Configurable optimizer, so that we can set the configurations
-        # such as learning rate by passing the dictionary later.
-        optimizer_conf = OptimizerConfigs()
-        optimizer_conf.parameters = self.model.parameters()
-        self.optimizer = optimizer_conf
-
-    def run(self):
-        """
-        ### Training loop
-
-        We do full batch training since the dataset is small.
-        If we were to sample and train we would have to sample a set of
-        nodes for each training step along with the edges that span
-        across those selected nodes.
-        """
-        # Move the feature vectors to the device
-        features = self.dataset.features.to(self.device)
-        # Move the labels to the device
-        labels = self.dataset.labels.to(self.device)
-        # Move the adjacency matrix to the device
-        edges_adj = self.dataset.adj_mat.to(self.device)
-        # Add an empty third dimension for the heads
-        edges_adj = edges_adj.unsqueeze(-1)
 
-        # Random indexes
-        idx_rand = torch.randperm(len(labels))
-        # Nodes for training
-        idx_train = idx_rand[:self.training_samples]
-        # Nodes for validation
-        idx_valid = idx_rand[self.training_samples:]
-
-        # Training loop
-        for epoch in monit.loop(self.epochs):
-            # Set the model to training mode
-            self.model.train()
-            # Make all the gradients zero
-            self.optimizer.zero_grad()
-            # Evaluate the model
-            output = self.model(features, edges_adj)
-            # Get the loss for training nodes
-            loss = self.loss_func(output[idx_train], labels[idx_train])
-            # Calculate gradients
-            loss.backward()
-            # Take optimization step
-            self.optimizer.step()
-            # Log the loss
-            tracker.add('loss.train', loss)
-            # Log the accuracy
-            tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
+    Since the experiment is the same as the [GAT experiment](../gat/experiment.html) but with
+    the [GATv2 model](index.html), we extend the same configs and change the model.
+    """
 
-            # Set mode to evaluation mode for validation
-            self.model.eval()
+    # Whether to share weights for source and target nodes of edges
+    share_weights: bool = True
+    # Set the model
+    model: GATv2 = 'gat_v2_model'
 
-            # No need to compute gradients
-            with torch.no_grad():
-                # Evaluate the model again
-                output = self.model(features, edges_adj)
-                # Calculate the loss for validation nodes
-                loss = self.loss_func(output[idx_valid], labels[idx_valid])
-                # Log the loss
-                tracker.add('loss.valid', loss)
-                # Log the accuracy
-                tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
 
-            # Save logs
-            tracker.save()
+@option(Configs.model)
+def gat_v2_model(c: Configs):
+    """
+    Create GATv2 model
+    """
+    return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)
 
 
 def main():
@@ -288,9 +101,9 @@ def main():
         'optimizer.optimizer': 'Adam',
         'optimizer.learning_rate': 5e-3,
         'optimizer.weight_decay': 5e-4,
+
+        'dropout': 0.7,
     })
-    # Initialize
-    conf.initialize()
 
     # Start and watch the experiment
     with experiment.start():
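
The second file now declares only what differs from the GAT experiment: `Configs(GATConfigs)` inherits the dataset, optimizer, and training loop, and swaps the model by naming an option function in the default value (`model: GATv2 = 'gat_v2_model'`). A toy sketch of that select-an-option-by-name convention, assuming the behavior shown in this diff; `Base`, `Fancy`, and `fancy_greeting` are hypothetical names:

from labml.configs import BaseConfigs, option


class Base(BaseConfigs):
    # A plain default value
    greeting: str = 'hello'


class Fancy(Base):
    # A string matching a registered option name selects that option,
    # exactly as `model: GATv2 = 'gat_v2_model'` does above
    greeting: str = 'fancy_greeting'


@option(Fancy.greeting)
def fancy_greeting(c: Fancy):
    # Runs when `greeting` is resolved for `Fancy` configs
    return 'hello, from an option function'

Options registered for the base configs (such as `gat_model`) stay available in the subclass, so a run could presumably switch back to the original model by passing `{'model': 'gat_model'}` to `experiment.configs`.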
