-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
112 lines (92 loc) · 3.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import yaml
import os
import torch
import random
import copy
import dgl
import numpy as np
from train import train_baseline, train_gnn
def parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: carries ``method``, ``encoder``, ``dataset``,
        ``gpu`` and ``num_seeds``, each falling back to the default declared
        below when the flag is not given.
    """
    method_choices = ['sunny-gnn', 'gat', 'gcn']
    encoder_choices = ['gat', 'gcn']
    dataset_choices = [
        'citeseer', 'cora', 'pubmed',
        'amazon-photo', 'coauthor-physics', 'coauthor-cs',
    ]

    cli = argparse.ArgumentParser(description='Self-explainable GNN')

    # Categorical options: argparse rejects any value outside `choices`.
    cli.add_argument('--method', type=str, default='sunny-gnn',
                     choices=method_choices,
                     help='self-explainable GNN type')
    cli.add_argument('--encoder', type=str, default='gat',
                     choices=encoder_choices,
                     help='GNN encoder type')
    cli.add_argument('--dataset', type=str, default='citeseer',
                     choices=dataset_choices,
                     help='dataset name')

    # Integer options.
    cli.add_argument('--gpu', type=int, default=0, help='gpu id')
    cli.add_argument('--num_seeds', type=int, default=5,
                     help='number of random seeds')

    return cli.parse_args()
class Config:
    """Run configuration derived from the parsed command-line arguments.

    Construction resolves the dataset and checkpoint locations relative to
    the directory containing this file, fetches the dataset when its graph
    file is missing, and loads the per-dataset hyperparameter YAML.

    Attributes set by ``__init__``:
        method, encoder_type, dataset, gpu: copied from ``args``.
        abs_dir: directory containing this file.
        data_dir: ``<abs_dir>/dataset/<dataset>``.
        graph_path, index_path: ``*_graph.bin`` / ``*_index.bin`` in data_dir.
        index: initialised to ``None``; not populated by this class.
        ckpt_dir: ``<abs_dir>/ckpt``.
        hyparams: whatever ``load_hyperparams`` returns.

    Attributes set later by ``set_seed``:
        seed, encoder_path.
    """

    def __init__(self, args):
        """Populate the configuration from an ``argparse`` namespace.

        Args:
            args: object exposing ``method``, ``encoder``, ``dataset`` and
                ``gpu`` attributes.
        """
        root = os.path.dirname(os.path.realpath(__file__))
        dataset_dir = os.path.join(root, 'dataset', args.dataset)

        # Values taken straight from the command line.
        self.method = args.method
        self.encoder_type = args.encoder
        self.dataset = args.dataset
        self.gpu = args.gpu

        # Dataset locations; these must exist before check_dataset() runs.
        self.abs_dir = root
        self.data_dir = dataset_dir
        self.index = None
        self.graph_path = f'{dataset_dir}/{args.dataset}_graph.bin'
        self.index_path = f'{dataset_dir}/{args.dataset}_index.bin'
        self.check_dataset()

        self.ckpt_dir = os.path.join(root, 'ckpt')
        self.hyparams = self.load_hyperparams(args)

    def check_dataset(self):
        """Fetch the dataset when its graph file is not present on disk."""
        if os.path.exists(self.graph_path):
            return
        # Imported lazily so the helper is only needed when data is missing.
        from tools.get_data import get_dataset
        get_dataset(self.dataset, self.data_dir)

    def load_hyperparams(self, args):
        """Read ``configs/<dataset>.yml`` and return its parsed content.

        Args:
            args: object exposing a ``dataset`` attribute.

        Returns:
            The object produced by ``yaml.load`` for that file.
        """
        config_file = os.path.join(self.abs_dir, 'configs', f'{args.dataset}.yml')
        with open(config_file, 'r') as stream:
            return yaml.load(stream, Loader=yaml.FullLoader)

    def set_seed(self, seed):
        """Record ``seed`` and derive the matching pretrained-encoder path.

        Args:
            seed: seed value embedded in the checkpoint file name.
        """
        self.seed = seed
        checkpoint = f'{self.ckpt_dir}/{self.dataset}/{self.encoder_type}-seed-{seed}-pretrain.pt'
        self.encoder_path = checkpoint
def setup_seed(seed):
    """Seed every random number generator this script touches.

    Seeds Python's ``random``, NumPy, PyTorch (CPU and all CUDA devices) and
    DGL with the same value, then sets the cuDNN ``deterministic`` and
    ``benchmark`` flags.

    Args:
        seed: integer seed passed to each library.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators and DGL.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    dgl.seed(seed)

    # cuDNN flags: deterministic on, benchmark mode off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def main():
    """Train the configured method once per seed and summarise the metrics.

    Reads the module-level ``args`` and ``cfg`` objects, which the
    ``__main__`` guard of this file creates before calling this function.

    For each seed in ``range(args.num_seeds)`` the RNGs are seeded, the seed
    is recorded on ``cfg``, and one of two paths runs:

    * ``'sunny-gnn'``: if no pretrained encoder checkpoint exists at
      ``cfg.encoder_path``, the encoder is first trained via
      ``train_gnn.train`` on a deep copy of ``cfg`` whose ``method`` is set
      to the encoder type; then ``train_baseline.train(cfg)`` is run.
    * ``'gat'`` / ``'gcn'``: ``train_gnn.train(cfg)`` is run directly.

    Returns:
        dict: maps each metric name to ``[values, mean, std]``, where
        ``values`` is the list of per-seed results.

    Raises:
        NotImplementedError: if ``cfg.method`` is none of the methods above.
    """
    results = {}
    for seed in range(args.num_seeds):
        setup_seed(seed)
        cfg.set_seed(seed)
        print(f'===========seed: {seed}===========')

        if cfg.method == 'sunny-gnn':
            print(f"Dataset: {cfg.dataset}, Method: {cfg.method}-{cfg.encoder_type}")
            if not os.path.exists(cfg.encoder_path):
                print(f"Pretrain {cfg.encoder_type}...")
                # Pretrain on a copy so the caller's cfg keeps its method.
                cfg_cp = copy.deepcopy(cfg)
                cfg_cp.method = cfg_cp.encoder_type
                train_gnn.train(cfg_cp)
            print(f"Train {cfg.method}...")
            metrics = train_baseline.train(cfg)
        elif cfg.method in ('gat', 'gcn'):
            print(f"Dataset: {cfg.dataset}, Method: {cfg.method}")
            metrics = train_gnn.train(cfg)
        else:
            raise NotImplementedError(f'unsupported method: {cfg.method}')

        # setdefault covers both the first seed and any metric key that
        # only shows up in a later seed.
        for name, value in metrics.items():
            results.setdefault(name, []).append(value)

    print('===========results===========')
    summary = {}
    for name, values in results.items():
        # Compute each statistic once and reuse it for storage and printing.
        mean = sum(values) / len(values)
        std = np.std(values)
        summary[name] = [values, mean, std]
        print(f'>>> {name}: {mean}, {std}')
    return summary
if __name__ == "__main__":
    # `args` and `cfg` are deliberately bound at module level:
    # main() looks both names up as globals.
    args = parse_args()
    cfg = Config(args)
    main()