-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
42 lines (38 loc) · 1.79 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
from hgcn_utils.train_utils import add_flags_from_config
config_args = {
'training_config': {
'log': (None, 'None for no logging'),
'lr': (0.001, 'learning rate'),
'batch-size': (10000, 'batch size'),
'epochs': (500, 'maximum number of epochs to train for'),
'weight-decay': (0.005, 'l2 regularization strength'),
'momentum': (0.95, 'momentum in optimizer'),
'seed': (1234, 'seed for data split'),
'train_seed': (1234, 'seed for training'),
'log-freq': (1, 'how often to compute print train/val metrics (in epochs)'),
'eval-freq': (20, 'how often to compute val metrics (in epochs)'),
'r': (2., 'fermi-dirac decoder parameter for lp'),
't': (1., 'fermi-dirac decoder parameter for lp'),
'i': (1, 'the number of iteration times')
},
'model_config': {
'embedding_dim': (50, 'user item embedding dimension'),
'scale': (0.1, 'scale for init'),
'dim': (50, 'embedding dimension'),
'network': ('resSumGCN', 'choice of StackGCNs, plainGCN, resSumGCN'),
'c': (1, 'hyperbolic radius, set to None for trainable curvature'),
'num-layers': (4, 'number of hidden layers in encoder'),
'margin': (0.1, 'margin value in the metric learning loss'),
'alpha': (20, "scale factor for geometric regularization")
},
'data_config': {
'dataset': ('Amazon-CD', 'which dataset to use'),
'num_neg': (1, 'number of negative samples'),
'test_ratio': (0.2, 'proportion of test edges for link prediction'),
'norm_adj': ('True', 'whether to row-normalize the adjacency matrix'),
}
}
parser = argparse.ArgumentParser()
for _, config_dict in config_args.items():
parser = add_flags_from_config(parser, config_dict)