[](https://app.labml.ai/run/34b1e2f6ed6f11ebb860997901a2d1e3)
"""
- from typing import Dict
-
- import numpy as np
import torch
from torch import nn

- from labml import lab, monit, tracker, experiment
- from labml.configs import BaseConfigs
- from labml.utils import download
- from labml_helpers.device import DeviceConfigs
+ from labml import experiment
+ from labml.configs import option
from labml_helpers.module import Module
+ from labml_nn.graphs.gat.experiment import Configs as GATConfigs
from labml_nn.graphs.gatv2 import GraphAttentionV2Layer
- from labml_nn.optimizers.configs import OptimizerConfigs
-
-
- class CoraDataset:
-     """
-     ## [Cora Dataset](https://linqs.soe.ucsc.edu/data)
-
-     The Cora dataset is a dataset of research papers.
-     For each paper we are given a binary feature vector that indicates the presence of words.
-     Each paper is classified into one of 7 classes.
-     The dataset also has the citation network.
-
-     The papers are the nodes of the graph and the edges are the citations.
-
-     The task is to classify the nodes into the 7 classes, using the feature vectors and
-     the citation network as input.
-     """
-     # Labels for each node
-     labels: torch.Tensor
-     # Class names and a unique integer index for each of them
-     classes: Dict[str, int]
-     # Feature vectors for all nodes
-     features: torch.Tensor
-     # Adjacency matrix with the edge information.
-     # `adj_mat[i][j]` is `True` if there is an edge from `i` to `j`.
-     adj_mat: torch.Tensor
-
-     @staticmethod
-     def _download():
-         """
-         Download the dataset
-         """
-         if not (lab.get_data_path() / 'cora').exists():
-             download.download_file('https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz',
-                                    lab.get_data_path() / 'cora.tgz')
-             download.extract_tar(lab.get_data_path() / 'cora.tgz', lab.get_data_path())
-
-     def __init__(self, include_edges: bool = True):
-         """
-         Load the dataset
-         """
-
-         # Whether to include edges.
-         # This is to test how much accuracy is lost if we ignore the citation network.
-         self.include_edges = include_edges
-
-         # Download the dataset
-         self._download()
-
-         # Read the paper ids, feature vectors, and labels
-         with monit.section('Read content file'):
-             content = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.content'), dtype=np.dtype(str))
-         # Load the citations; it's a list of pairs of integers.
-         with monit.section('Read citations file'):
-             citations = np.genfromtxt(str(lab.get_data_path() / 'cora/cora.cites'), dtype=np.int32)
-
-         # Get the feature vectors
-         features = torch.tensor(np.array(content[:, 1:-1], dtype=np.float32))
-         # Normalize the feature vectors
-         self.features = features / features.sum(dim=1, keepdim=True)
-
-         # Get the class names and assign a unique integer to each of them
-         self.classes = {s: i for i, s in enumerate(set(content[:, -1]))}
-         # Get the labels as those integers
-         self.labels = torch.tensor([self.classes[i] for i in content[:, -1]], dtype=torch.long)
-
-         # Get the paper ids
-         paper_ids = np.array(content[:, 0], dtype=np.int32)
-         # Map of paper id to index
-         ids_to_idx = {id_: i for i, id_ in enumerate(paper_ids)}
-
-         # Empty adjacency matrix - an identity matrix
-         self.adj_mat = torch.eye(len(self.labels), dtype=torch.bool)
-
-         # Mark the citations in the adjacency matrix
-         if self.include_edges:
-             for e in citations:
-                 # The pair of paper indexes
-                 e1, e2 = ids_to_idx[e[0]], ids_to_idx[e[1]]
-                 # We build a symmetric graph, where if paper $i$ referenced
-                 # paper $j$ we place an edge from $i$ to $j$ as well as an edge
-                 # from $j$ to $i$.
-                 self.adj_mat[e1][e2] = True
-                 self.adj_mat[e2][e1] = True
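Since this file now builds on the GAT experiment module, the dataset class above moves out of this file. As a quick reference, a minimal usage sketch of `CoraDataset`, assuming the standard Cora statistics (2,708 papers, 1,433 word features, 7 classes):

# A minimal sketch, not part of the diff; shapes assume standard Cora.
dataset = CoraDataset(include_edges=True)
# One normalized word-presence vector per paper
assert dataset.features.shape == (2708, 1433)
# Seven paper classes, one integer label per node
assert len(dataset.classes) == 7
# Boolean adjacency: symmetric citation edges plus self-loops on the diagonal
assert dataset.adj_mat.shape == (2708, 2708)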


class GATv2(Module):
@@ -115,7 +27,8 @@ class GATv2(Module):
    This graph attention network has two [graph attention layers](index.html).
    """

-     def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float, share_weights: bool = True):
+     def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float,
+                  share_weights: bool = True):
        """
        * `in_features` is the number of features per node
        * `n_hidden` is the number of features in the first graph attention layer
@@ -127,11 +40,13 @@ def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int
        super().__init__()

        # First graph attention layer where we concatenate the heads
-         self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout, share_weights=share_weights)
+         self.layer1 = GraphAttentionV2Layer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout,
+                                             share_weights=share_weights)
        # Activation function after the first graph attention layer
        self.activation = nn.ELU()
        # Final graph attention layer where we average the heads
-         self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout, share_weights=share_weights)
+         self.output = GraphAttentionV2Layer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout,
+                                             share_weights=share_weights)
        # Dropout
        self.dropout = nn.Dropout(dropout)
@@ -153,128 +68,26 @@ def __call__(self, x: torch.Tensor, adj_mat: torch.Tensor):
        return self.output(x, adj_mat)
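A forward-pass sketch of `GATv2` with random stand-in tensors; the shapes follow the Cora setup, and the trailing singleton dimension on the adjacency matrix is the per-head dimension the attention layers expect (see the `unsqueeze(-1)` in the removed training loop below):

# A minimal sketch with random inputs, not part of the diff.
model = GATv2(in_features=1433, n_hidden=64, n_classes=7, n_heads=8, dropout=0.6)
x = torch.rand(2708, 1433)
# Boolean adjacency with self-loops; the last dimension broadcasts over heads
adj_mat = torch.eye(2708, dtype=torch.bool).unsqueeze(-1)
logits = model(x, adj_mat)  # shape: (2708, 7)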


- def accuracy(output: torch.Tensor, labels: torch.Tensor):
-     """
-     A simple function to calculate the accuracy
-     """
-     return output.argmax(dim=-1).eq(labels).sum().item() / len(labels)
-
-
- class Configs(BaseConfigs):
+ class Configs(GATConfigs):
    """
    ## Configurations
-     """
-
-     # Model
-     model: GATv2
-     # Number of nodes to train on
-     training_samples: int = 500
-     # Number of features per node in the input
-     in_features: int
-     # Number of features in the first graph attention layer
-     n_hidden: int = 64
-     # Number of heads
-     n_heads: int = 8
-     # Number of classes for classification
-     n_classes: int
-     # Dropout probability
-     dropout: float = 0.7
-     # Whether to include the citation network
-     include_edges: bool = True
-     # Dataset
-     dataset: CoraDataset
-     # Number of training iterations
-     epochs: int = 1_000
-     # Loss function
-     loss_func = nn.CrossEntropyLoss()
-     # Device to train on
-     #
-     # This creates configs for the device, so that
-     # we can change the device by passing a config value
-     device: torch.device = DeviceConfigs()
-     # Optimizer
-     optimizer: torch.optim.Adam
-
-     def initialize(self):
-         """
-         Initialize
-         """
-         # Create the dataset
-         self.dataset = CoraDataset(self.include_edges)
-         # Get the number of classes
-         self.n_classes = len(self.dataset.classes)
-         # Number of features in the input
-         self.in_features = self.dataset.features.shape[1]
-         # Create the model
-         self.model = GATv2(self.in_features, self.n_hidden, self.n_classes, self.n_heads, self.dropout)
-         # Move the model to the device
-         self.model.to(self.device)
-         # Configurable optimizer, so that we can set
-         # configurations such as the learning rate by passing a dictionary later.
-         optimizer_conf = OptimizerConfigs()
-         optimizer_conf.parameters = self.model.parameters()
-         self.optimizer = optimizer_conf
-
-     def run(self):
-         """
-         ### Training loop
-
-         We do full-batch training since the dataset is small.
-         If we were to sample and train, we would have to sample a set of
-         nodes for each training step, along with the edges that span
-         those selected nodes.
-         """
-         # Move the feature vectors to the device
-         features = self.dataset.features.to(self.device)
-         # Move the labels to the device
-         labels = self.dataset.labels.to(self.device)
-         # Move the adjacency matrix to the device
-         edges_adj = self.dataset.adj_mat.to(self.device)
-         # Add an empty third dimension for the heads
-         edges_adj = edges_adj.unsqueeze(-1)
-
-         # Random indexes
-         idx_rand = torch.randperm(len(labels))
-         # Nodes for training
-         idx_train = idx_rand[:self.training_samples]
-         # Nodes for validation
-         idx_valid = idx_rand[self.training_samples:]
-
-         # Training loop
-         for epoch in monit.loop(self.epochs):
-             # Set the model to training mode
-             self.model.train()
-             # Make all the gradients zero
-             self.optimizer.zero_grad()
-             # Evaluate the model
-             output = self.model(features, edges_adj)
-             # Get the loss for training nodes
-             loss = self.loss_func(output[idx_train], labels[idx_train])
-             # Calculate gradients
-             loss.backward()
-             # Take an optimization step
-             self.optimizer.step()
-             # Log the loss
-             tracker.add('loss.train', loss)
-             # Log the accuracy
-             tracker.add('accuracy.train', accuracy(output[idx_train], labels[idx_train]))
+     Since the experiment is the same as the [GAT experiment](../gat/experiment.html) but with
+     the [GATv2 model](index.html), we extend the same configs and change the model.
+     """

-             # Set the mode to evaluation mode for validation
-             self.model.eval()

+     # Whether to share weights for source and target nodes of edges
+     share_weights: bool = True
+     # Set the model
+     model: GATv2 = 'gat_v2_model'
-             # No need to compute gradients
-             with torch.no_grad():
-                 # Evaluate the model again
-                 output = self.model(features, edges_adj)
-                 # Calculate the loss for validation nodes
-                 loss = self.loss_func(output[idx_valid], labels[idx_valid])
-                 # Log the loss
-                 tracker.add('loss.valid', loss)
-                 # Log the accuracy
-                 tracker.add('accuracy.valid', accuracy(output[idx_valid], labels[idx_valid]))
-
-             # Save logs
-             tracker.save()
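The removed `run()` docstring explains the full-batch choice: to mini-batch a graph we would have to sample nodes and keep only the edges that span the sample. A hypothetical sketch of that step; `sample_subgraph` and `n_sample` are illustrative, not part of labml:

# Hypothetical helper, for illustration only.
def sample_subgraph(features: torch.Tensor, adj_mat: torch.Tensor, n_sample: int):
    # Choose a random subset of nodes
    idx = torch.randperm(len(features))[:n_sample]
    # Keep only the edges with both endpoints in the sample
    sub_adj = adj_mat[idx][:, idx]
    return features[idx], sub_adj, idx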

+ @option(Configs.model)
+ def gat_v2_model(c: Configs):
+     """
+     Create a GATv2 model
+     """
+     return GATv2(c.in_features, c.n_hidden, c.n_classes, c.n_heads, c.dropout, c.share_weights).to(c.device)
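The string default `'gat_v2_model'` on `Configs.model` names this `@option` function: labml computes the option lazily the first time `conf.model` is read, after any overrides have been applied. Assuming the usual `experiment.configs` override dictionary (as in `main()` below), switching off weight sharing is then just another override:

# A sketch of overriding the new option; assumes the override dict API used in main().
conf = Configs()
experiment.configs(conf, {'share_weights': False})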


def main():
@@ -288,9 +101,9 @@ def main():
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 5e-3,
        'optimizer.weight_decay': 5e-4,
+
+         'dropout': 0.7,
    })
-     # Initialize
-     conf.initialize()

    # Start and watch the experiment
    with experiment.start():