Skip to content

Commit ff83c11

Browse files
Merge pull request #79 from joshchang1112:master
PiperOrigin-RevId: 351154923
2 parents 2f72a16 + 449dcea commit ff83c11

File tree

4 files changed

+62
-36
lines changed

4 files changed

+62
-36
lines changed

research/gnn-survey/README.md

+21-13
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33
This repository contains a tensorflow 2.0 implementation of GNN models for node
44
classification.
55

6-
Update: ** Sparse version of GCN and GAT are available **
6+
Update: **Sparse versions of GCN, GAT and GIN are available**
77

88
## Code organization
99

10-
* `download_dataset.sh`: Download graph dataset. Now only `cora` is available.
10+
* `download_dataset.sh`: Download graph dataset. Now `cora` and `citeseer` are
11+
available.
1112

1213
* `train.py`: Trains a model with FLAGS parameters. `python train.py --help`
1314
for more information.
1415

15-
* `models.py`: Gnn models implementation. Now `gcn` and `gat` are available.
16+
* `models.py`: GNN model implementations. Now `gcn`, `gat` and `gin` are
17+
available.
1618

1719
* `layers.py`: Implementations of individual GNN layers.
1820

@@ -25,7 +27,7 @@ Update: ** Sparse version of GCN and GAT are available **
2527
2. Download the dataset.
2628

2729
```
28-
bash download_dataset.sh
30+
bash download_dataset.sh <DATASET>
2931
```
3032

3133
1. Train GAT on cora with default parameters.
@@ -39,7 +41,8 @@ python train.py --save_dir=$SAVE_DIR
3941

4042
## Training Results
4143

42-
* Better GAT results on cora (84.7% average test accuracy)[^GAT]:
44+
* Better GAT results on cora (84.7% average test accuracy,
45+
[[2]](#references)):
4346

4447
```
4548
python train.py \
@@ -56,7 +59,8 @@ python train.py \
5659
--sparse_features=True
5760
```
5861

59-
* Reproduce gcn results on cora (81.5% average test accuracy)[^GCN]:
62+
* Reproduce gcn results on cora (81.5% average test accuracy,
63+
[[1]](#references)):
6064

6165
```
6266
python train.py \
@@ -72,7 +76,8 @@ python train.py \
7276
--sparse_features=True
7377
```
7478

75-
* Better gcn results on cora (82.5% average test accuracy):
79+
* Better gcn results on cora (82.5% average test accuracy,
80+
[[1]](#references)):
7681

7782
```
7883
python train.py \
@@ -88,7 +93,7 @@ python train.py \
8893
--sparse_features=True
8994
```
9095

91-
* GIN results on cora (81.7% average test accuracy):
96+
* GIN results on cora (81.7% average test accuracy, [[3]](#references)):
9297

9398
```
9499
python train.py \
@@ -107,8 +112,11 @@ python train.py \
107112

108113
## References
109114

110-
[^GCN]: Thomas N. Kipf, Max Welling "Semi-Supervised Classification with Graph
111-
Convolutional Networks"
112-
[GCN original github](https://github.com/tkipf/gcn/tree/master/gcn)
113-
[^GAT]: Petar Veličković, Guillem Cucurull, et al. "Graph Attention Networks"
114-
[GAT original github](https://github.com/PetarV-/GAT)
115+
[[1] T. Kipf and M. Welling. "Semi-Supervised Classification with Graph
116+
Convolutional Networks" ICLR 2017](https://arxiv.org/pdf/1609.02907.pdf)
117+
118+
[[2] P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò and Y. Bengio.
119+
"Graph Attention Networks" ICLR 2018](https://arxiv.org/pdf/1710.10903.pdf)
120+
121+
[[3] K. Xu, W. Hu, J. Leskovec and S. Jegelka. "How Powerful are Graph Neural
122+
Networks?" ICLR 2019](https://arxiv.org/pdf/1810.00826.pdf)

research/gnn-survey/download_dataset.sh

+5-4
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# limitations under the License.
1414
# TODO(ppham27): Consider consolidating with
1515
# examples/preprocess/cora/prep_data.sh.
16-
# URL for downloading Cora dataset.
17-
URL=https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
16+
# URL for downloading the dataset (available datasets: cora, citeseer).
17+
DATASET=$1
18+
URL=https://linqs-data.soe.ucsc.edu/public/lbc/${DATASET}.tgz
1819

1920
# Target folder to store and process data.
2021
DATA_DIR=data
@@ -32,6 +33,6 @@ function download () {
3233
fi
3334
}
3435

35-
# Download and unzip the dataset. Data will be at '${DATA_DIR}/cora/' folder.
36+
# Download and unzip the dataset. Data will be in the '${DATA_DIR}/${DATASET}/' folder.
3637
download ${URL} ${DATA_DIR}
37-
tar -C ${DATA_DIR} -xvzf ${DATA_DIR}/cora.tgz
38+
tar -C ${DATA_DIR} -xvzf ${DATA_DIR}/${DATASET}.tgz

research/gnn-survey/train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
from utils import load_dataset, build_model, cal_acc # pylint: disable=g-multiple-import
2222

23-
flags.DEFINE_enum('dataset', 'cora', ['cora'],
24-
'The input dataset. Avaliable dataset now: cora')
23+
flags.DEFINE_enum('dataset', 'cora', ['cora', 'citeseer'],
24+
'The input dataset. Avaliable dataset now: cora, citeseer')
2525
flags.DEFINE_enum('model', 'gat', ['gcn', 'gat', 'gin'],
2626
'GNN model. Available model now: gcn, gat')
2727
flags.DEFINE_float('dropout_rate', 0.6, 'Dropout probability')

research/gnn-survey/utils.py

+34-17
Original file line numberDiff line numberDiff line change
@@ -65,24 +65,34 @@ def cal_acc(labels, logits):
6565
return acc.numpy().item()
6666

6767

68-
def encode_onehot(labels):
68+
def encode_onehot(dataset, labels):
6969
"""Provides a mapping from string labels to integer indices."""
7070
label_index = {
71-
'Case_Based': 0,
72-
'Genetic_Algorithms': 1,
73-
'Neural_Networks': 2,
74-
'Probabilistic_Methods': 3,
75-
'Reinforcement_Learning': 4,
76-
'Rule_Learning': 5,
77-
'Theory': 6,
71+
'cora': {
72+
'Case_Based': 0,
73+
'Genetic_Algorithms': 1,
74+
'Neural_Networks': 2,
75+
'Probabilistic_Methods': 3,
76+
'Reinforcement_Learning': 4,
77+
'Rule_Learning': 5,
78+
'Theory': 6
79+
},
80+
'citeseer': {
81+
'AI': 0,
82+
'IR': 1,
83+
'HCI': 2,
84+
'DB': 3,
85+
'ML': 4,
86+
'Agents': 5
87+
}
7888
}
7989

8090
# Convert to onehot label
81-
num_classes = len(label_index)
91+
num_classes = len(label_index[dataset])
8292
onehot_labels = np.zeros((len(labels), num_classes))
8393
idx = 0
8494
for s in labels:
85-
onehot_labels[idx, label_index[s]] = 1
95+
onehot_labels[idx, label_index[dataset][s]] = 1
8696
idx += 1
8797
return onehot_labels
8898

@@ -115,23 +125,30 @@ def sparse_matrix_to_tf_sparse_tensor(matrix):
115125

116126

117127
def load_dataset(dataset, sparse_features, normalize_adj):
118-
"""Loads Cora dataset."""
128+
"""Loads dataset."""
119129
dir_path = os.path.join('data', dataset)
120130
content_path = os.path.join(dir_path, '{}.content'.format(dataset))
121131
citation_path = os.path.join(dir_path, '{}.cites'.format(dataset))
122132

123133
content = np.genfromtxt(content_path, dtype=np.dtype(str))
124-
125-
idx = np.array(content[:, 0], dtype=np.int32)
134+
idx = np.array(content[:, 0])
126135
features = sp.csr_matrix(content[:, 1:-1], dtype=np.float32)
127-
labels = encode_onehot(content[:, -1])
136+
labels = encode_onehot(dataset, content[:, -1])
128137

129138
# Dict which maps paper id to data id
130139
idx_map = {j: i for i, j in enumerate(idx)}
131-
edges_unordered = np.genfromtxt(citation_path, dtype=np.int32)
140+
edges_unordered = np.genfromtxt(citation_path, dtype=np.dtype(str))
132141
edges = np.array(
133142
list(map(idx_map.get, edges_unordered.flatten())),
134-
dtype=np.int32).reshape(edges_unordered.shape)
143+
dtype=np.dtype(str)).reshape(edges_unordered.shape)
144+
145+
# Delete relations whose nodes appear in cites but not in content
146+
del_rel = []
147+
for i, j in enumerate(edges):
148+
if j[0] == 'None' or j[1] == 'None':
149+
del_rel.append(i)
150+
edges = np.delete(edges, del_rel, 0)
151+
135152
adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
136153
shape=(labels.shape[0], labels.shape[0]),
137154
dtype=np.float32)
@@ -145,13 +162,13 @@ def load_dataset(dataset, sparse_features, normalize_adj):
145162
if normalize_adj:
146163
adj = normalize_adj_matrix(adj)
147164

148-
# 5% for train, 300 for validation, 1000 for test
149165
idx_train = slice(140)
150166
idx_val = slice(200, 500)
151167
idx_test = slice(500, 1500)
152168

153169
features = tf.convert_to_tensor(np.array(features.todense()))
154170
labels = tf.convert_to_tensor(np.where(labels)[1])
171+
155172
if sparse_features:
156173
adj = sparse_matrix_to_tf_sparse_tensor(adj)
157174
else:

0 commit comments

Comments
 (0)