main_old_kg.py
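
# main_old_kg.py is the training entry point for a knowledge-graph-augmented, R2Gen-style
# radiology report generator: it parses the command-line options, builds the tokenizer and
# dataloaders, constructs a GCN classifier over a fixed knowledge-graph adjacency matrix,
# plugs it into R2GenModel, and hands everything to the Trainer.
#
# Example invocation (flags taken from the parser below; paths are illustrative):
#   python main_old_kg.py --dataset_name iu_xray --kg_option rgmg --feed_mode both \
#       --pretrained models/gcnclassifier_v2_ones3_t401v2t3_lr1e-6_e80.pth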
import torch
import argparse
import numpy as np
from modules.tokenizers import Tokenizer
from modules.dataloaders import R2DataLoader
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from modules.trainer import Trainer
from modules.loss import compute_loss
from models.r2gen import R2GenModel
from modules.mlclassifier import GCNClassifier


def parse_agrs():
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--image_dir', type=str, default='data/iu_xray/images/', help='the path to the directory containing the data.')
    parser.add_argument('--ann_path', type=str, default='data/iu_xray/annotation.json', help='the path to the annotation file.')

    # Data loader settings
    parser.add_argument('--dataset_name', type=str, default='iu_xray', choices=['iu_xray', 'mimic_cxr', 'mimic_cxr_2images'], help='the dataset to be used.')
    parser.add_argument('--max_seq_length', type=int, default=60, help='the maximum sequence length of the reports.')
    parser.add_argument('--threshold', type=int, default=3, help='the cut off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=4, help='the number of workers for dataloader.')
    parser.add_argument('--batch_size', type=int, default=16, help='the number of samples for a batch.')

    # Model settings (for visual extractor)
    #edit
    parser.add_argument('--visual_extractor', type=str, default='densenet121', help='the visual extractor to be used.')
    parser.add_argument('--visual_extractor_pretrained', type=bool, default=True, help='whether to load the pretrained visual extractor.')

    # Model settings (for Transformer)
    parser.add_argument('--d_model', type=int, default=512, help='the dimension of Transformer.')
    parser.add_argument('--d_ff', type=int, default=512, help='the dimension of FFN.')
    #edit
    parser.add_argument('--d_vf', type=int, default=2048, help='the dimension of the patch features.')
    parser.add_argument('--num_heads', type=int, default=8, help='the number of heads in Transformer.')
    parser.add_argument('--num_layers', type=int, default=3, help='the number of layers of Transformer.')
    parser.add_argument('--dropout', type=float, default=0.1, help='the dropout rate of Transformer.')
    parser.add_argument('--logit_layers', type=int, default=1, help='the number of logit layers.')
    parser.add_argument('--bos_idx', type=int, default=0, help='the index of <bos>.')
    parser.add_argument('--eos_idx', type=int, default=0, help='the index of <eos>.')
    parser.add_argument('--pad_idx', type=int, default=0, help='the index of <pad>.')
    parser.add_argument('--use_bn', type=int, default=0, help='whether to use batch normalization.')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5, help='the dropout rate of the output layer.')

    # for Relational Memory
    parser.add_argument('--rm_num_slots', type=int, default=3, help='the number of memory slots.')
    parser.add_argument('--rm_num_heads', type=int, default=8, help='the number of heads in rm.')
    parser.add_argument('--rm_d_model', type=int, default=512, help='the dimension of rm.')

    # Sample related
    parser.add_argument('--sample_method', type=str, default='beam_search', help='the sampling method used to sample a report.')
    parser.add_argument('--beam_size', type=int, default=3, help='the beam size when beam searching.')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when sampling.')
    parser.add_argument('--sample_n', type=int, default=1, help='the sample number per image.')
    parser.add_argument('--group_size', type=int, default=1, help='the group size.')
    parser.add_argument('--output_logsoftmax', type=int, default=1, help='whether to output the probabilities.')
    parser.add_argument('--decoding_constraint', type=int, default=0, help='whether to apply the decoding constraint.')
    parser.add_argument('--block_trigrams', type=int, default=1, help='whether to block repeated trigrams.')

    # Trainer settings
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--epochs', type=int, default=100, help='the number of training epochs.')
    parser.add_argument('--save_dir', type=str, default='results/iu_xray', help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='records/', help='the path to save the results of experiments.')
    parser.add_argument('--save_period', type=int, default=1, help='the saving period.')
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'], help='whether to maximize or minimize the monitored metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4', help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of early stopping.')

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=5e-5, help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=1e-4, help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    parser.add_argument('--amsgrad', type=bool, default=True, help='whether to use AMSGrad for Adam.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR', help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50, help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1, help='the gamma of the learning rate scheduler.')

    # Knowledge graph (KG)
    parser.add_argument('--pretrained', type=str, default='models/gcnclassifier_v2_ones3_t401v2t3_lr1e-6_e80.pth', help='the path of the pretrained GCN classifier.')
    parser.add_argument('--feed_mode', type=str, default='both', choices=['both', 'cnn_only', 'gcn_only'], help='which features to feed to the Transformer as input.')
    parser.add_argument('--kg_option', type=str, default='rgmg', choices=['rgmg', 'vsegcn'], help='the knowledge graph used for the IU X-Ray dataset.')

    # Pretrained Language Models
    parser.add_argument('--pretrained_LM', type=str, default='none', choices=['none', 'glove-mimic', 'biobert', 'bioalbert'], help='the pretrained language model to be used.')
    parser.add_argument('--glove_path', type=str, default='models/glove_mimic-cxr_train.512.txt.gz', help='the path of the pretrained language model GloVe-MIMIC.')
    parser.add_argument('--bioalbert_path', type=str, default='models/bioalbert', help='the path of the pretrained language model BioAlbert.')

    # Others
    parser.add_argument('--seed', type=int, default=9233, help='the random seed.')
    parser.add_argument('--resume', type=str, help='the checkpoint path to resume training from.')
    parser.add_argument('--flip', type=bool, default=False, help='if True, the two images randomly switch positions with probability 0.5.')

    args = parser.parse_args()
    return args


def main():
    # parse arguments
    args = parse_agrs()

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loader
    train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    #edit
    device = torch.device('cuda')
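
    # Each dataset gets a hand-built knowledge-graph adjacency matrix (fw_adj) over what
    # appear to be finding/abnormality nodes; node 0 acts as a root connected to every
    # other node. The 'rgmg' graph for IU X-Ray is binary, while the 'vsegcn' and
    # MIMIC-CXR graphs carry real-valued edge weights.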
    if args.dataset_name == 'iu_xray':
        if args.kg_option == 'rgmg':
            fw_adj = torch.tensor([
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
], dtype=torch.float, device=device)
        elif args.kg_option == 'vsegcn':
            fw_adj = torch.tensor([
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08624227881040383, 0.0, 0.0, 0.0, 0.0, 0.08531678128946102, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.01865074607267221, 0.0, 0.2924299133554616, 0.0, 0.0, 0.21304488410089617, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17180192556684684, 0.6418055548125825, 0.0, 0.5855658364897064, 0.0, 0.8458489347533727, 0.9602592859311171, 0.4274547554463511, 0.0, 0.7595885904689661, 0.0] ,
[0.0, 0.0, 0.01865074607267221, 0.0, 0.0, 0.4009853199098073, 1.5255730949811777, 0.0, 0.08521151259101123, 0.631755218959081, 0.0, 0.752383206747696, 0.0, 0.0, 0.331650626508743, 0.426960806313068, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7569183619130871, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.5610118144338234, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49754224048515705, 0.059920020742461305, 0.12193009552667401, 0.0, 0.6916982549261147, 0.0, 0.028649105523448872, 0.0, 1.38484543548606, 0.0, 0.0, 0.0, 0.6395125017555444] ,
[0.0, 0.0, 0.2924299133554616, 0.4009853199098073, 0.5610118144338234, 0.0, 1.522720025998771, 0.6039632016294225, 1.1321805681072825, 0.9342837995278563, 1.281557969181883, 0.8673131734216728, 1.2434062032175066, 0.3490255809790961, 0.6754221656115674, 0.4370111421665694, 0.0, 0.0, 0.0, 0.0, 0.40348845012792584, 0.0, 0.0, 0.06928636204125185, 0.0] ,
[0.0, 0.0, 0.0, 1.5255730949811777, 0.0, 1.522720025998771, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1794995623878415, 0.0, 0.0, 1.535623430834679, 0.0, 0.0, 0.06231769272515846, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.6039632016294225, 0.0, 0.0, 0.0, 1.5819475025089174, 0.0, 0.3162811291776416, 0.0, 0.5912241758519928, 0.7710172862925886, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8175373019274815, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.21304488410089617, 0.08521151259101123, 0.0, 1.1321805681072825, 0.0, 0.0, 0.0, 0.9061920646608415, 0.0, 0.7391379799976754, 0.0, 0.0, 1.1938741371126225, 0.0952618484445128, 0.0, 0.0, 0.05704063562431511, 0.0, 0.0, 0.16859312153006234, 0.0, 1.376195693906577, 0.0] ,
[0.0, 0.0, 0.0, 0.631755218959081, 0.0, 0.9342837995278563, 0.0, 1.5819475025089174, 0.9061920646608415, 0.0, 0.9602592859311171, 1.6911467944739096, 0.3242705192111205, 0.39747392323441544, 0.8649491061267922, 0.50827416218806, 0.0, 0.0, 0.0, 0.0, 0.623787049309904, 0.0, 0.021989647338186796, 0.0, 0.46624078048150774] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 1.281557969181883, 0.0, 0.0, 0.0, 0.9602592859311171, 0.0, 0.793205201267951, 0.0, 1.068148247942302, 1.535623430834679, 0.0, 0.0, 0.46778280083332285, 0.0, 0.0, 1.294461374017791, 0.0, 0.0, 0.0, 0.6669114759436587] ,
[0.0, 0.0, 0.0, 0.752383206747696, 0.0, 0.8673131734216728, 2.1794995623878415, 0.3162811291776416, 0.7391379799976754, 1.6911467944739096, 0.793205201267951, 0.0, 0.0, 0.007276287257039512, 1.3910422020235713, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23358941333252825, 0.0, 0.0, 0.3693909544915901, 0.29918669581834156] ,
[0.0, 0.0, 0.0, 0.0, 0.49754224048515705, 1.2434062032175066, 0.0, 0.0, 0.0, 0.3242705192111205, 0.0, 0.0, 0.0, 0.4321594812223054, 0.0, 0.20648748355473706, 0.2166398550187551, 0.0, 0.16826627073453929, 0.0, 0.6584726072977941, 0.0, 0.0, 0.0, 1.4172170703435527] ,
[0.0, 0.0, 0.0, 0.0, 0.059920020742461305, 0.3490255809790961, 0.0, 0.5912241758519928, 0.0, 0.39747392323441544, 1.068148247942302, 0.007276287257039512, 0.4321594812223054, 0.0, 0.6161631241992449, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1179703724409795, 0.35302216066358155, 0.0, 0.6443340011659412, 0.0] ,
[0.0, 0.0, 0.17180192556684684, 0.331650626508743, 0.12193009552667401, 0.6754221656115674, 1.535623430834679, 0.7710172862925886, 1.1938741371126225, 0.8649491061267922, 1.535623430834679, 1.3910422020235713, 0.0, 0.6161631241992449, 0.0, 0.5240225191561991, 0.0, 0.0, 0.0, 0.0, 1.0937906785556397, 0.3096717197899676, 0.0, 1.430262915176853, 0.3484577448251243] ,
[0.0, 0.0, 0.6418055548125825, 0.426960806313068, 0.0, 0.4370111421665694, 0.0, 0.0, 0.0952618484445128, 0.50827416218806, 0.0, 0.0, 0.20648748355473706, 0.0, 0.5240225191561991, 0.0, 0.0, 0.0, 0.0, 0.2272906111845001, 0.5060040136535207, 0.0, 0.3096717197899676, 0.4186620034983729, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.6916982549261147, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2166398550187551, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10158746275990369, 0.029803617870273677, 0.0, 0.0, 0.137502534460031, 0.0, 0.07092804383736119] ,
[0.0, 0.08624227881040383, 0.5855658364897064, 0.0, 0.0, 0.0, 0.06231769272515846, 0.0, 0.0, 0.0, 0.46778280083332285, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1461991767058608, 0.845848934753373, 0.0, 0.0, 0.9958502310338198, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.028649105523448872, 0.0, 0.0, 0.0, 0.05704063562431511, 0.0, 0.0, 0.0, 0.16826627073453929, 0.0, 0.0, 0.0, 0.10158746275990369, 0.1461991767058608, 0.0, 0.08964361822629091, 0.0034771927022252134, 0.0, 0.08912895017581536, 0.0, 0.06907447518803858] ,
[0.0, 0.0, 0.8458489347533727, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2272906111845001, 0.029803617870273677, 0.845848934753373, 0.08964361822629091, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.9602592859311171, 0.0, 1.38484543548606, 0.40348845012792584, 0.0, 0.8175373019274815, 0.0, 0.623787049309904, 1.294461374017791, 0.23358941333252825, 0.6584726072977941, 2.1179703724409795, 1.0937906785556397, 0.5060040136535207, 0.0, 0.0, 0.0034771927022252134, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.4274547554463511, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16859312153006234, 0.0, 0.0, 0.0, 0.0, 0.35302216066358155, 0.3096717197899676, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49199327658392233, 0.6449325692248835] ,
[0.0, 0.08531678128946102, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.021989647338186796, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3096717197899676, 0.137502534460031, 0.9958502310338198, 0.08912895017581536, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.7595885904689661, 0.7569183619130871, 0.0, 0.06928636204125185, 0.0, 0.0, 1.376195693906577, 0.0, 0.0, 0.3693909544915901, 0.0, 0.6443340011659412, 1.430262915176853, 0.4186620034983729, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49199327658392233, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.6395125017555444, 0.0, 0.0, 0.0, 0.0, 0.46624078048150774, 0.6669114759436587, 0.29918669581834156, 1.4172170703435527, 0.0, 0.3484577448251243, 0.0, 0.07092804383736119, 0.0, 0.06907447518803858, 0.0, 0.0, 0.6449325692248835, 0.0, 0.0, 0.0] ,
], dtype=torch.float, device=device)
        else:
            raise ValueError('INVALID KG OPTION!')

    if args.dataset_name == 'mimic_cxr' or args.dataset_name == 'mimic_cxr_2images':
        fw_adj = torch.tensor([
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ,
[0.0, 0.0, 0.554004673298568, 0.3230466430334976, 0.2631825039749924, 0.7016287008480984, 0.24746163509836083, 0.0010015098042039912, 0.5068736479106769, 0.0, 0.9524974820767607, 0.0, 0.274863833501621, 0.41494699762748494, 0.5840689600939755, 0.0, 0.0, 0.1302442384969287, 0.0, 0.6339873741551628, 0.0, 0.0, 0.0, 0.0, 0.13757450134915059, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20484800268752174, 0.0, 0.3451355245988348, 0.0, 0.0, 0.31496065589330335] ,
[0.0, 0.554004673298568, 0.0, 0.0, 0.8777428349239078, 0.48756689514453816, 0.15609090494487807, 0.0, 0.2680360388191779, 0.0, 0.6199516595911282, 0.42238326753620825, 0.11635380051050957, 0.0020601220064393085, 0.6321690244338739, 0.0, 0.0, 0.08334337914718853, 0.5194350922631013, 0.12854873550846002, 0.6928120918404297, 0.0, 0.0, 0.9837299370816517, 0.2314271028213519, 0.5817251164456204, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33811183663241207, 0.0, 0.13936476949018967, 0.0, 0.0, 0.07089242056885223] ,
[0.0, 0.3230466430334976, 0.0, 0.0, 0.49497396267096466, 0.6594324804768135, 0.0, 0.706453671957811, 0.4967156785324145, 0.0, 0.8143326399158364, 0.09215690907124767, 1.0617056232382251, 0.1725070158434512, 0.3677521118316441, 0.0, 0.0, 0.0, 0.0, 0.31306543839429407, 1.8437693102858905, 0.0, 0.0, 0.0, 0.12775432598701517, 0.0, 0.1052712068749145, 0.0, 0.0, 0.0, 0.0, 0.037260086847355586, 0.0, 0.7070285247253186, 0.0, 0.0, 0.36112875829496877] ,
[0.0, 0.2631825039749924, 0.8777428349239078, 0.49497396267096466, 0.0, 0.5081730413487658, 0.0, 0.0, 0.4345747745875523, 0.0, 0.9271810641740084, 0.0, 0.43551881399664394, 0.0, 0.32510239740035224, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5363277306682555, 0.0, 0.0, 0.35190615068655606, 0.5705232126845626, 0.5633978920587661, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6523482076409699, 0.05734284387314562, 0.029039130518496915, 0.0, 0.0, 0.47600439204660566] ,
[0.0, 0.7016287008480984, 0.48756689514453816, 0.6594324804768135, 0.5081730413487658, 0.0, 0.5879416440743396, 0.7282368105665861, 0.6655288739610595, 0.0, 0.5829843060689561, 1.1084681450499108, 0.23022619090918064, 0.7869908341567744, 0.9237236161240492, 0.0, 0.0, 0.00679805799268029, 0.3508343765662928, 0.7283807004431154, 1.2142749896595786, 0.0, 0.0, 0.21380809279477983, 0.0, 0.0, 0.0, 0.7384514725987071, 0.036197082879009246, 0.024547912417246583, 0.0, 0.0031487780150971255, 0.0, 0.17255665551694127, 0.0, 0.30464217507912805, 0.5800348339999603] ,
[0.0, 0.24746163509836083, 0.15609090494487807, 0.0, 0.0, 0.5879416440743396, 0.0, 0.0, 0.0, 0.0, 0.050047418864644054, 1.1040108892060951, 0.0, 1.0758558236410007, 0.20507481280870768, 0.0, 0.0, 0.20482994095079318, 0.20102931464117388, 0.5609723973768375, 0.5907785254095393, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1108353406999546, 2.2722344254125035, 0.6708537856962394, 0.0, 0.0, 0.0, 0.656008131578071, 0.019380062443701017, 0.0] ,
[0.0, 0.0010015098042039912, 0.0, 0.706453671957811, 0.0, 0.7282368105665861, 0.0, 0.0, 0.8928317911567151, 0.0, 0.2412417134011427, 1.1756920438484275, 0.45754054092273994, 0.0, 0.0, 0.07862826830847675, 0.0, 0.08084792186159706, 0.17752298133873087, 0.0, 0.23167676395941764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22902820929658516, 0.0, 0.18822158336999106, 1.3653688666139736, 0.0, 0.0, 0.08262146112756505, 0.8291646636467345, 0.3474221602648512, 0.028679991724856257, 0.15823233737507106] ,
[0.0, 0.5068736479106769, 0.2680360388191779, 0.4967156785324145, 0.4345747745875523, 0.6655288739610595, 0.0, 0.8928317911567151, 0.0, 0.0, 0.5266439565166569, 0.7732181640322486, 0.8796823806455875, 0.04357377094721733, 0.38483053231749587, 0.0, 0.0, 0.01715279715950329, 0.0, 0.1935198098014637, 0.7726446050833446, 0.0, 0.0, 0.01203069298044543, 0.009614236915435854, 0.0, 0.028445852424580496, 0.0, 0.017029382779510584, 0.1773653331099141, 0.0, 0.0, 0.0, 0.7478386613505874, 0.0, 0.0, 0.3988597091752934] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11819146504902717, 0.3285711228697073, 0.0, 0.0, 0.0, 0.0, 0.1284847079738626, 0.3441683893946561, 0.0, 0.0, 0.0, 0.056702571766598944, 0.19444256855484526, 0.0, 0.0, 0.034767887372249416, 0.0, 0.05535863446625411, 0.0, 0.0, 0.2006116936209859, 0.0] ,
[0.0, 0.9524974820767607, 0.6199516595911282, 0.8143326399158364, 0.9271810641740084, 0.5829843060689561, 0.050047418864644054, 0.2412417134011427, 0.5266439565166569, 0.0, 0.0, 0.27606588231803425, 0.2702763655065547, 0.5155464612930855, 0.5345046760872463, 0.0, 0.0, 0.0, 0.0, 0.6619248249548156, 0.11676497651446016, 0.0, 0.0, 0.012298886029973912, 0.36531125789699864, 0.22014333096829503, 0.0, 0.0, 0.1421597490181444, 0.0, 0.0, 0.5023838505056603, 0.007094738878332813, 0.16482707917210404, 0.12422641272658975, 0.0, 0.16806740202956802] ,
[0.0, 0.0, 0.42238326753620825, 0.09215690907124767, 0.0, 1.1084681450499108, 1.1040108892060951, 1.1756920438484275, 0.7732181640322486, 0.0, 0.27606588231803425, 0.0, 0.4571939169349286, 0.44285881406541816, 0.014435627721654973, 0.0, 0.0, 0.0, 0.35645685756909534, 0.3273858184572163, 0.0, 0.0, 0.0, 0.0, 0.22124160687808572, 0.11022312228841263, 0.04830169724837715, 0.0, 0.92753653417296, 0.8948599965191218, 0.15171053363715095, 0.28220534987677603, 0.3001183720922399, 0.5143515680120436, 0.2778525728199598, 0.0, 0.0] ,
[0.0, 0.274863833501621, 0.11635380051050957, 1.0617056232382251, 0.43551881399664394, 0.23022619090918064, 0.0, 0.45754054092273994, 0.8796823806455875, 0.0, 0.2702763655065547, 0.4571939169349286, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.448038949914647, 0.0, 0.0, 0.07261956980307784, 0.0, 0.0, 0.07421899390203622, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.594487035114069, 0.0, 0.0, 0.0] ,
[0.0, 0.41494699762748494, 0.0020601220064393085, 0.1725070158434512, 0.0, 0.7869908341567744, 1.0758558236410007, 0.0, 0.04357377094721733, 0.0, 0.5155464612930855, 0.44285881406541816, 0.0, 0.0, 0.9595683261315329, 0.0, 0.0, 0.0, 0.0, 2.4365238654909045, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4720717303053189, 0.5727697527229471, 0.0, 0.0, 0.0, 0.9575685889709006, 1.2377447423336991, 0.0, 0.0] ,
[0.0, 0.5840689600939755, 0.6321690244338739, 0.3677521118316441, 0.32510239740035224, 0.9237236161240492, 0.20507481280870768, 0.0, 0.38483053231749587, 0.0, 0.5345046760872463, 0.014435627721654973, 0.0, 0.9595683261315329, 0.0, 0.17363312653927349, 0.19895016785061964, 0.0, 0.0, 1.208717109379973, 0.0, 0.015167532069747917, 0.0, 0.06703478683378344, 0.6177554201058338, 0.4406078996794569, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3022525474987056, 0.0884563575795982, 0.0, 1.2708509083599127] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07862826830847675, 0.0, 0.11819146504902717, 0.0, 0.0, 0.0, 0.0, 0.17363312653927349, 0.0, 2.447688346482624, 0.17983209096647654, 0.0, 0.0, 0.0, 0.28338141933933964, 0.0, 0.0, 0.29702177034947524, 0.07753776902134689, 0.007825343616291445, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5708374900647346, 0.0, 0.15510336927054139, 0.3542350803376888, 0.32669031419868527] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3285711228697073, 0.0, 0.0, 0.0, 0.0, 0.19895016785061964, 2.447688346482624, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14503213452982808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.139880360686599, 0.0, 0.1439133818303898, 0.0, 0.0] ,
[0.0, 0.1302442384969287, 0.08334337914718853, 0.0, 0.0, 0.00679805799268029, 0.20482994095079318, 0.08084792186159706, 0.01715279715950329, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.17983209096647654, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4124407401173898, 0.27031693276904833, 0.6962440948684386, 0.01868026565576617, 0.0, 0.058711117532272095, 0.3108414982982819, 0.8917732120346288, 0.5028865181216288, 0.3710258416111474, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.5194350922631013, 0.0, 0.0, 0.3508343765662928, 0.20102931464117388, 0.17752298133873087, 0.0, 0.0, 0.0, 0.35645685756909534, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3095615625659667, 0.0, 0.0, 0.33960729056256184, 0.14426353106038403, 0.48295034078748456, 0.04096791345986367, 0.0, 0.21459856249590378, 0.3394630945304, 0.9426984665379273, 0.22121090818098985, 0.3817453264511766, 0.3146902017201184, 0.2525866980495684, 0.041488578522214464, 0.0] ,
[0.0, 0.6339873741551628, 0.12854873550846002, 0.31306543839429407, 0.0, 0.7283807004431154, 0.5609723973768375, 0.0, 0.1935198098014637, 0.0, 0.6619248249548156, 0.3273858184572163, 0.0, 2.4365238654909045, 1.208717109379973, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.023645690589139633, 0.15053136759027846, 0.0, 0.24360268371473795, 0.0, 0.8821797270781289, 1.442822630962518, 0.0, 0.21531865324924443] ,
[0.0, 0.0, 0.6928120918404297, 1.8437693102858905, 0.5363277306682555, 1.2142749896595786, 0.5907785254095393, 0.23167676395941764, 0.7726446050833446, 0.0, 0.11676497651446016, 0.0, 1.448038949914647, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3095615625659667, 0.0, 0.0, 0.0, 0.0, 0.33727296191859424, 0.0, 0.11424727258813801, 0.004340630185881583, 0.0, 0.0, 0.17398826794432173, 0.699350130525858, 0.0, 0.0, 0.4114035987596012, 0.0, 0.8603703454658308, 0.6698830923247899] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1284847079738626, 0.0, 0.0, 0.0, 0.0, 0.015167532069747917, 0.28338141933933964, 0.14503213452982808, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6128893031292931, 0.0, 0.0, 0.0, 0.3311971887232972, 0.0, 0.221278095246663, 0.7821347726775432, 0.4309320418653175] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3441683893946561, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04939114176166614, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.9837299370816517, 0.0, 0.35190615068655606, 0.21380809279477983, 0.0, 0.0, 0.01203069298044543, 0.0, 0.012298886029973912, 0.0, 0.07261956980307784, 0.0, 0.06703478683378344, 0.0, 0.0, 0.4124407401173898, 0.33960729056256184, 0.0, 0.33727296191859424, 0.0, 0.0, 0.0, 0.0, 0.08979034229911677, 0.0, 0.38518142976105196, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.13757450134915059, 0.2314271028213519, 0.12775432598701517, 0.5705232126845626, 0.0, 0.0, 0.0, 0.009614236915435854, 0.0, 0.36531125789699864, 0.22124160687808572, 0.0, 0.0, 0.6177554201058338, 0.29702177034947524, 0.0, 0.27031693276904833, 0.14426353106038403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.84531965510853, 0.0, 0.0, 0.0, 0.0, 0.03024606051820196, 0.9015367412440394, 0.4187448648843476, 0.11315910768263383, 1.2343811011663848, 0.0, 0.24145505239298407] ,
[0.0, 0.0, 0.5817251164456204, 0.0, 0.5633978920587661, 0.0, 0.0, 0.0, 0.0, 0.0, 0.22014333096829503, 0.11022312228841263, 0.0, 0.0, 0.4406078996794569, 0.07753776902134689, 0.0, 0.6962440948684386, 0.48295034078748456, 0.0, 0.11424727258813801, 0.0, 0.0, 0.08979034229911677, 2.84531965510853, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3231554047178068, 1.2532084879227057, 0.5614893243263129, 0.046137943483739924, 1.540851917651678, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.1052712068749145, 0.0, 0.0, 0.0, 0.22902820929658516, 0.028445852424580496, 0.056702571766598944, 0.0, 0.04830169724837715, 0.07421899390203622, 0.0, 0.0, 0.007825343616291445, 0.0, 0.01868026565576617, 0.04096791345986367, 0.0, 0.004340630185881583, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0564784860967394, 0.12567556923304798, 0.06594334312122997, 0.0, 0.0, 0.06688509334837037, 0.36269590043980027, 0.030319915064814112, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.7384514725987071, 0.0, 0.0, 0.0, 0.19444256855484526, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04939114176166614, 0.38518142976105196, 0.0, 0.0, 0.0564784860967394, 0.0, 0.0, 0.0, 0.0, 0.2104088857291676, 0.4561074463243395, 0.0, 0.0, 0.0, 0.6247400008366559] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.036197082879009246, 0.1108353406999546, 0.18822158336999106, 0.017029382779510584, 0.0, 0.1421597490181444, 0.92753653417296, 0.0, 0.4720717303053189, 0.0, 0.0, 0.0, 0.058711117532272095, 0.21459856249590378, 0.023645690589139633, 0.0, 0.6128893031292931, 0.0, 0.0, 0.0, 0.0, 0.12567556923304798, 0.0, 0.0, 0.4309948864530749, 0.0, 0.02542734722301152, 0.0, 0.0, 0.2510742769865065, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.024547912417246583, 2.2722344254125035, 1.3653688666139736, 0.1773653331099141, 0.0, 0.0, 0.8948599965191218, 0.0, 0.5727697527229471, 0.0, 0.0, 0.0, 0.3108414982982819, 0.3394630945304, 0.15053136759027846, 0.17398826794432173, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06594334312122997, 0.0, 0.4309948864530749, 0.0, 0.8692491673212553, 0.0, 0.19435111707051575, 0.266221588915103, 0.676971258598654, 0.20262509195680156, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6708537856962394, 0.0, 0.0, 0.034767887372249416, 0.0, 0.15171053363715095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8917732120346288, 0.9426984665379273, 0.0, 0.699350130525858, 0.0, 0.0, 0.0, 0.03024606051820196, 0.3231554047178068, 0.0, 0.0, 0.0, 0.8692491673212553, 0.0, 1.1605097349751408, 0.3461994713504898, 0.0, 1.5082427073912366, 0.0, 0.0] ,
[0.0, 0.20484800268752174, 0.33811183663241207, 0.037260086847355586, 0.6523482076409699, 0.0031487780150971255, 0.0, 0.0, 0.0, 0.0, 0.5023838505056603, 0.28220534987677603, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5028865181216288, 0.22121090818098985, 0.24360268371473795, 0.0, 0.0, 0.0, 0.0, 0.9015367412440394, 1.2532084879227057, 0.0, 0.2104088857291676, 0.02542734722301152, 0.0, 1.1605097349751408, 0.0, 1.681814915277294, 0.0, 0.43925238577431364, 0.04338409257246071, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.05734284387314562, 0.0, 0.0, 0.08262146112756505, 0.0, 0.05535863446625411, 0.007094738878332813, 0.3001183720922399, 0.0, 0.0, 0.0, 1.5708374900647346, 2.139880360686599, 0.3710258416111474, 0.3817453264511766, 0.0, 0.0, 0.3311971887232972, 0.0, 0.0, 0.4187448648843476, 0.5614893243263129, 0.06688509334837037, 0.4561074463243395, 0.0, 0.19435111707051575, 0.3461994713504898, 1.681814915277294, 0.0, 0.2674633965945188, 0.8740258958015569, 0.2474464019557586, 0.0] ,
[0.0, 0.3451355245988348, 0.13936476949018967, 0.7070285247253186, 0.029039130518496915, 0.17255665551694127, 0.0, 0.8291646636467345, 0.7478386613505874, 0.0, 0.16482707917210404, 0.5143515680120436, 0.594487035114069, 0.9575685889709006, 0.3022525474987056, 0.0, 0.0, 0.0, 0.3146902017201184, 0.8821797270781289, 0.4114035987596012, 0.0, 0.0, 0.0, 0.11315910768263383, 0.046137943483739924, 0.36269590043980027, 0.0, 0.0, 0.266221588915103, 0.0, 0.0, 0.2674633965945188, 0.0, 0.42342916886466814, 0.0, 0.33323103097930845] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.656008131578071, 0.3474221602648512, 0.0, 0.0, 0.12422641272658975, 0.2778525728199598, 0.0, 1.2377447423336991, 0.0884563575795982, 0.15510336927054139, 0.1439133818303898, 0.0, 0.2525866980495684, 1.442822630962518, 0.0, 0.221278095246663, 0.0, 0.0, 1.2343811011663848, 1.540851917651678, 0.030319915064814112, 0.0, 0.2510742769865065, 0.676971258598654, 1.5082427073912366, 0.43925238577431364, 0.8740258958015569, 0.42342916886466814, 0.0, 0.0, 0.0] ,
[0.0, 0.0, 0.0, 0.0, 0.0, 0.30464217507912805, 0.019380062443701017, 0.028679991724856257, 0.0, 0.2006116936209859, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3542350803376888, 0.0, 0.0, 0.041488578522214464, 0.0, 0.8603703454658308, 0.7821347726775432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20262509195680156, 0.0, 0.04338409257246071, 0.2474464019557586, 0.0, 0.0, 0.0, 0.0] ,
[0.0, 0.31496065589330335, 0.07089242056885223, 0.36112875829496877, 0.47600439204660566, 0.5800348339999603, 0.0, 0.15823233737507106, 0.3988597091752934, 0.0, 0.16806740202956802, 0.0, 0.0, 0.0, 1.2708509083599127, 0.32669031419868527, 0.0, 0.0, 0.0, 0.21531865324924443, 0.6698830923247899, 0.4309320418653175, 0.0, 0.0, 0.24145505239298407, 0.0, 0.0, 0.6247400008366559, 0.0, 0.0, 0.0, 0.0, 0.0, 0.33323103097930845, 0.0, 0.0, 0.0] ,
], dtype=torch.float, device=device)
# Old adjacency matrix that used Test set for training
# fw_adj = torch.tensor([
# [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ,
# [0.0, 0.0, 0.4561275751719785, 0.1312437281509043, 0.24146207792929056, 0.5508571762992422, 0.1562972606280271, 0.0, 0.33109282796069434, 0.0, 0.7450978096449961, 0.0, 0.09811571952749415, 0.3566126916516296, 0.4550427125255811, 0.0, 0.0, 0.0, 0.0, 0.3675326731479569, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0742334537640819, 0.0, 0.0, 0.04911539003545221] ,
# [0.0, 0.4561275751719785, 0.0, 0.2239353858402398, 0.7210226952471281, 0.1813069680770458, 0.0788364926609991, 0.0, 0.15580397237632496, 0.0, 0.5213886531771248, 0.29374910091081574, 0.07338686671308393, 0.0, 0.47453376787334545, 0.0, 0.0, 0.0, 0.06713445127088644, 0.0, 0.2390801083216384, 0.0, 0.0, 0.5368211205907667, 0.0, 0.12777400002905853, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ,
# [0.0, 0.1312437281509043, 0.2239353858402398, 0.0, 0.6583894493157434, 0.4838989820766221, 0.0, 0.5838389742905232, 0.1802535145956454, 0.0, 0.7991245475975839, 0.11752099418944446, 1.06322936147921, 0.14944553664268462, 0.555751891083365, 0.0, 0.0, 0.0, 0.0, 0.0, 1.194781559134564, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06242749638070836, 0.0, 0.0, 0.0] ,
# [0.0, 0.24146207792929056, 0.7210226952471281, 0.6583894493157434, 0.0, 0.33739430398543896, 0.0, 0.0, 0.2753154828514788, 0.0, 0.7937543770840491, 0.0, 0.44113179799973046, 0.0, 0.3813008244727325, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2757852680983795, 0.0, 0.0, 0.10440811928574718, 0.32527161563918205, 0.30268868564421475, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4051634698397344, 0.0, 0.0, 0.0, 0.0, 0.23170964380526302] ,
# [0.0, 0.5508571762992422, 0.1813069680770458, 0.4838989820766221, 0.33739430398543896, 0.0, 0.5084212742711277, 0.6675018415064685, 0.4573178232633596, 0.0, 0.47523016454835715, 0.655060126377754, 0.04554159429343136, 0.690421666641882, 0.558453872850085, 0.0, 0.0, 0.0, 0.0, 0.20852065737883554, 0.6848377491773682, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2185149665526724, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04939763863424854] ,
# [0.0, 0.1562972606280271, 0.0788364926609991, 0.0, 0.0, 0.5084212742711277, 0.0, 0.0, 0.0, 0.0, 0.005426229300043639, 1.3057145824116705, 0.0, 0.8349693844237195, 0.008921790018546403, 0.0, 0.0, 0.11331204695360785, 0.09144882429915285, 0.4600716943289044, 0.4835214391246888, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02380115748556941, 2.1599410696116648, 0.5656942084851855, 0.0, 0.0, 0.0, 0.5385812163077162, 0.0, 0.0] ,
# [0.0, 0.0, 0.0, 0.5838389742905232, 0.0, 0.6675018415064685, 0.0, 0.0, 0.731672905950786, 0.0, 0.19873089907410066, 1.0717887940778525, 0.4314328875601624, 0.0, 0.0, 0.0, 0.0, 0.0, 0.084615347661925, 0.0, 0.12484394581839403, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13075929436986614, 0.0, 0.08945725659182202, 1.2587049974217708, 0.0, 0.0, 0.0, 0.7231133523008619, 0.24889659448837284, 0.0, 0.06274981602259916] ,
# [0.0, 0.33109282796069434, 0.15580397237632496, 0.1802535145956454, 0.2753154828514788, 0.4573178232633596, 0.0, 0.731672905950786, 0.0, 0.0, 0.3658987804839403, 0.5684867664090824, 0.8076453675499128, 0.0, 0.24600870168688302, 0.0, 0.0, 0.0, 0.0, 0.0798570028567058, 0.6598408690505047, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.06942284468142029, 0.0, 0.0, 0.0, 0.6390156086235599, 0.0, 0.0, 0.2877393797550929] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.37032811695643797, 0.5864507940803438, 0.09648724571559682, 0.08646096743214723, 0.0, 0.0, 0.3795124898565471, 0.5956873219192063, 0.0, 0.0, 0.029427346497570606, 0.3055219304384092, 0.44290494604282604, 0.18031374756716573, 0.0, 0.2817348748514817, 0.0, 0.3037353594240949, 0.0, 0.16902414094509366, 0.44897290997326605, 0.0] ,
# [0.0, 0.7450978096449961, 0.5213886531771248, 0.7991245475975839, 0.7937543770840491, 0.47523016454835715, 0.005426229300043639, 0.19873089907410066, 0.3658987804839403, 0.0, 0.0, 0.15917680219264457, 0.23815070153292747, 0.32301546536533154, 0.4733720905109685, 0.0, 0.0, 0.0, 0.0, 0.3050316392916182, 0.0, 0.0, 0.0, 0.0, 0.007496236022830504, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1459218682394831, 0.0, 0.0, 0.0, 0.0, 0.0] ,
# [0.0, 0.0, 0.29374910091081574, 0.11752099418944446, 0.0, 0.655060126377754, 1.3057145824116705, 1.0717887940778525, 0.5684867664090824, 0.0, 0.15917680219264457, 0.0, 0.34332677414193835, 0.2635635043486729, 0.031005417323580087, 0.0, 0.0, 0.0, 0.09089149252550566, 0.02437348630097624, 0.0, 0.0, 0.0, 0.0, 0.03277105749999201, 0.0, 0.0, 0.0, 0.6391742933567508, 0.6037980683638654, 0.0, 0.0, 0.0033714685259266476, 0.23107114275084603, 0.0301477199380383, 0.0, 0.0] ,
# [0.0, 0.09811571952749415, 0.07338686671308393, 1.06322936147921, 0.44113179799973046, 0.04554159429343136, 0.0, 0.4314328875601624, 0.8076453675499128, 0.0, 0.23815070153292747, 0.34332677414193835, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.3533266696615023, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5094212566296229, 0.0, 0.0, 0.0] ,
# [0.0, 0.3566126916516296, 0.0, 0.14944553664268462, 0.0, 0.690421666641882, 0.8349693844237195, 0.0, 0.0, 0.0, 0.32301546536533154, 0.2635635043486729, 0.0, 0.0, 0.6931424489513275, 0.0, 0.0, 0.0, 0.0, 2.1074196113031913, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.15331793735154012, 0.23884807415186537, 0.0, 0.0, 0.0, 0.6251962447501228, 0.9033262867029686, 0.0, 0.0] ,
# [0.0, 0.4550427125255811, 0.47453376787334545, 0.555751891083365, 0.3813008244727325, 0.558453872850085, 0.008921790018546403, 0.0, 0.24600870168688302, 0.0, 0.4733720905109685, 0.031005417323580087, 0.0, 0.6931424489513275, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7161636297822183, 0.0, 0.0, 0.0, 0.0, 0.11876223494432343, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7766379329419377] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.37032811695643797, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.826263272280364, 0.572616403287778, 0.37608135779331225, 0.0, 0.0, 0.6664791879765191, 0.0, 0.0, 0.6695968919616625, 0.4488791110312937, 0.38945634739693075, 0.08411274303185438, 0.2917004403335538, 0.3800437779653869, 0.2380677791988365, 0.27679858458872236, 1.9549590442141598, 0.3722905991057948, 0.5307931536036093, 0.7368876097778745, 0.7061522884727808] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5864507940803438, 0.0, 0.0, 0.0, 0.0, 0.0, 2.826263272280364, 0.0, 0.35120934068491455, 0.2375133841107677, 0.0, 0.0, 0.5321065210844121, 0.19927245248782205, 0.0, 0.14957868637482777, 0.0, 0.25536621908603513, 0.0, 0.0, 0.0, 0.0, 0.03734498368480446, 2.516696010229012, 0.0, 0.5307889978371716, 0.0, 0.0960955758632178] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11331204695360785, 0.0, 0.0, 0.09648724571559682, 0.0, 0.0, 0.0, 0.0, 0.0, 0.572616403287778, 0.35120934068491455, 0.0, 0.0, 0.0, 0.0, 0.29348082706206574, 0.33245104866618147, 0.7834253853708174, 0.6637012531860743, 1.0870265661824885, 0.39863505669400867, 0.13312286942605217, 0.4347084018504075, 0.6789980152639449, 1.2829314721133003, 0.8828246864504936, 0.750212390764988, 0.3394040245033589, 0.17754003752713748, 0.3162719096872125, 0.0] ,
# [0.0, 0.0, 0.06713445127088644, 0.0, 0.0, 0.0, 0.09144882429915285, 0.084615347661925, 0.0, 0.08646096743214723, 0.0, 0.09089149252550566, 0.0, 0.0, 0.0, 0.37608135779331225, 0.2375133841107677, 0.0, 0.0, 0.0, 1.6925691084401553, 0.3561732227230664, 0.027819402623596414, 0.7161674728252212, 0.5197157747849462, 0.8581083948136323, 0.42260822276000154, 0.3552253348597958, 0.5997658869837408, 0.720615333108806, 1.319266349117826, 0.5944386313930262, 0.7674724970376314, 0.693554702139206, 0.6387600967878838, 0.4217118279958016, 0.3030211175474031] ,
# [0.0, 0.3675326731479569, 0.0, 0.0, 0.0, 0.20852065737883554, 0.4600716943289044, 0.0, 0.0798570028567058, 0.0, 0.3050316392916182, 0.02437348630097624, 0.0, 2.1074196113031913, 0.7161636297822183, 0.0, 0.0, 0.0, 0.0, 0.0, 0.29382396236601094, 0.0, 0.0, 0.17754549443031645, 0.0, 0.0, 0.32355364182848556, 0.159505799102812, 0.411251947423955, 0.52515309826666, 0.0, 0.596335294846914, 0.013806128482410229, 1.2531067143858092, 1.8230394564602632, 0.13362385380628916, 0.6001536689980314] ,
# [0.0, 0.0, 0.2390801083216384, 1.194781559134564, 0.2757852680983795, 0.6848377491773682, 0.4835214391246888, 0.12484394581839403, 0.6598408690505047, 0.0, 0.0, 0.0, 1.3533266696615023, 0.0, 0.0, 0.0, 0.0, 0.0, 1.6925691084401553, 0.29382396236601094, 0.0, 0.0, 0.0, 0.7138805959551021, 0.11297953392649797, 0.4844330162160724, 0.38596546185743924, 0.036239842558855984, 0.0, 0.5547948133103655, 1.0803235644904228, 0.0, 0.0, 0.7882036912676615, 0.0, 1.240751663702056, 1.0477606549141325] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3795124898565471, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6664791879765191, 0.5321065210844121, 0.29348082706206574, 0.3561732227230664, 0.0, 0.0, 0.0, 0.2534515169611425, 0.0, 0.11246945995471487, 0.07334273346935428, 0.21668590043388075, 0.21941094489706722, 0.9975080075129525, 0.368692533676505, 0.21812689267933633, 0.3287631475798967, 0.7093092033967436, 0.1917101851895807, 0.6056386787874534, 1.1599331154514032, 0.8043554214386265] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5956873219192063, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19927245248782205, 0.33245104866618147, 0.027819402623596414, 0.0, 0.0, 0.2534515169611425, 0.0, 0.22725680593423186, 0.0, 0.0, 0.31555285272114897, 0.4267466727339617, 0.0, 0.1826978370690844, 0.0, 0.05379873696374713, 0.14234923917988185, 0.0, 0.08948062941912607, 0.20706340244730248, 0.0] ,
# [0.0, 0.0, 0.5368211205907667, 0.0, 0.10440811928574718, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7834253853708174, 0.7161674728252212, 0.17754549443031645, 0.7138805959551021, 0.0, 0.22725680593423186, 0.0, 0.0, 0.4592383775815212, 0.16538956868582963, 0.7629051452564362, 0.0, 0.05346010889771364, 0.0, 0.27509540397847715, 0.21213865741151783, 0.0, 0.22482877173867202, 0.0, 0.0] ,
# [0.0, 0.0, 0.0, 0.0, 0.32527161563918205, 0.0, 0.0, 0.0, 0.0, 0.0, 0.007496236022830504, 0.03277105749999201, 0.0, 0.0, 0.11876223494432343, 0.6695968919616625, 0.14957868637482777, 0.6637012531860743, 0.5197157747849462, 0.0, 0.11297953392649797, 0.11246945995471487, 0.0, 0.0, 0.0, 3.2134702267299677, 0.24101556699326157, 0.09145686480864357, 0.12919761705292898, 0.05039604762473017, 0.403553572864001, 1.3069750190832774, 0.7980425960864843, 0.47648926941152736, 1.6040951817884799, 0.0, 0.6259654177637075] ,
# [0.0, 0.0, 0.12777400002905853, 0.0, 0.30268868564421475, 0.0, 0.0, 0.0, 0.0, 0.029427346497570606, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4488791110312937, 0.0, 1.0870265661824885, 0.8581083948136323, 0.0, 0.4844330162160724, 0.07334273346935428, 0.0, 0.4592383775815212, 3.2134702267299677, 0.0, 0.2203493961718295, 0.018110233591376573, 0.10255522134088697, 0.26908677536189435, 0.7137634299128567, 1.6439628202548011, 0.925575399120091, 0.4163664995892516, 1.9116571827560045, 0.0, 0.0] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13075929436986614, 0.0, 0.3055219304384092, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38945634739693075, 0.25536621908603513, 0.39863505669400867, 0.42260822276000154, 0.32355364182848556, 0.38596546185743924, 0.21668590043388075, 0.31555285272114897, 0.16538956868582963, 0.24101556699326157, 0.2203493961718295, 0.0, 0.4371114399604061, 0.5054285531888771, 0.44924515354001515, 0.3648403693789466, 0.28056669723711136, 0.4471627735959794, 0.7437313561632167, 0.4121877509333448, 0.2726999196691929, 0.18922614685864272] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.2185149665526724, 0.0, 0.0, 0.0, 0.44290494604282604, 0.0, 0.0, 0.0, 0.0, 0.0, 0.08411274303185438, 0.0, 0.13312286942605217, 0.3552253348597958, 0.159505799102812, 0.036239842558855984, 0.21941094489706722, 0.4267466727339617, 0.7629051452564362, 0.09145686480864357, 0.018110233591376573, 0.4371114399604061, 0.0, 0.20271260668257918, 0.3668239365548465, 0.1184126119193528, 0.5934154784628667, 0.8376418968029923, 0.21619014199138478, 0.28461494476804783, 0.34161510088391334, 1.008623439095296] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02380115748556941, 0.08945725659182202, 0.0, 0.18031374756716573, 0.0, 0.6391742933567508, 0.0, 0.15331793735154012, 0.0, 0.2917004403335538, 0.0, 0.4347084018504075, 0.5997658869837408, 0.411251947423955, 0.0, 0.9975080075129525, 0.0, 0.0, 0.12919761705292898, 0.10255522134088697, 0.5054285531888771, 0.20271260668257918, 0.0, 0.8187047289562672, 0.1410302025894172, 0.3941373798499791, 0.300273095819943, 0.043970911938040466, 0.6307373897434683, 0.09322704591348681, 0.0] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.1599410696116648, 1.2587049974217708, 0.06942284468142029, 0.0, 0.0, 0.6037980683638654, 0.0, 0.23884807415186537, 0.0, 0.3800437779653869, 0.0, 0.6789980152639449, 0.720615333108806, 0.52515309826666, 0.5547948133103655, 0.368692533676505, 0.1826978370690844, 0.05346010889771364, 0.05039604762473017, 0.26908677536189435, 0.44924515354001515, 0.3668239365548465, 0.8187047289562672, 0.0, 1.2341946413520681, 0.35781208803482784, 0.5890899178792679, 0.6340530114404032, 1.0585876788500728, 0.5803311766446675, 0.0] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5656942084851855, 0.0, 0.0, 0.2817348748514817, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2380677791988365, 0.0, 1.2829314721133003, 1.319266349117826, 0.0, 1.0803235644904228, 0.21812689267933633, 0.0, 0.0, 0.403553572864001, 0.7137634299128567, 0.3648403693789466, 0.1184126119193528, 0.1410302025894172, 1.2341946413520681, 0.0, 1.5406418880016948, 0.7296881633928688, 0.0, 1.8831346562682474, 0.3277653544960015, 0.0] ,
# [0.0, 0.0, 0.0, 0.0, 0.4051634698397344, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1459218682394831, 0.0, 0.0, 0.0, 0.0, 0.27679858458872236, 0.03734498368480446, 0.8828246864504936, 0.5944386313930262, 0.596335294846914, 0.0, 0.3287631475798967, 0.05379873696374713, 0.27509540397847715, 1.3069750190832774, 1.6439628202548011, 0.28056669723711136, 0.5934154784628667, 0.3941373798499791, 0.35781208803482784, 1.5406418880016948, 0.0, 2.051483967313616, 0.35982751520577233, 0.8118043924443167, 0.4208658024914072, 0.0] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3037353594240949, 0.0, 0.0033714685259266476, 0.0, 0.0, 0.0, 1.9549590442141598, 2.516696010229012, 0.750212390764988, 0.7674724970376314, 0.013806128482410229, 0.0, 0.7093092033967436, 0.14234923917988185, 0.21213865741151783, 0.7980425960864843, 0.925575399120091, 0.4471627735959794, 0.8376418968029923, 0.300273095819943, 0.5890899178792679, 0.7296881633928688, 2.051483967313616, 0.0, 0.6607118414843173, 1.251934458983633, 0.6240485023131742, 0.12626765101490742] ,
# [0.0, 0.0742334537640819, 0.0, 0.06242749638070836, 0.0, 0.0, 0.0, 0.7231133523008619, 0.6390156086235599, 0.0, 0.0, 0.23107114275084603, 0.5094212566296229, 0.6251962447501228, 0.0, 0.3722905991057948, 0.0, 0.3394040245033589, 0.693554702139206, 1.2531067143858092, 0.7882036912676615, 0.1917101851895807, 0.0, 0.0, 0.47648926941152736, 0.4163664995892516, 0.7437313561632167, 0.21619014199138478, 0.043970911938040466, 0.6340530114404032, 0.0, 0.35982751520577233, 0.6607118414843173, 0.0, 0.7989276546306372, 0.0, 0.7007917534875767] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5385812163077162, 0.24889659448837284, 0.0, 0.16902414094509366, 0.0, 0.0301477199380383, 0.0, 0.9033262867029686, 0.0, 0.5307931536036093, 0.5307889978371716, 0.17754003752713748, 0.6387600967878838, 1.8230394564602632, 0.0, 0.6056386787874534, 0.08948062941912607, 0.22482877173867202, 1.6040951817884799, 1.9116571827560045, 0.4121877509333448, 0.28461494476804783, 0.6307373897434683, 1.0585876788500728, 1.8831346562682474, 0.8118043924443167, 1.251934458983633, 0.7989276546306372, 0.0, 0.0, 0.0] ,
# [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.44897290997326605, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7368876097778745, 0.0, 0.3162719096872125, 0.4217118279958016, 0.13362385380628916, 1.240751663702056, 1.1599331154514032, 0.20706340244730248, 0.0, 0.0, 0.0, 0.2726999196691929, 0.34161510088391334, 0.09322704591348681, 0.5803311766446675, 0.3277653544960015, 0.4208658024914072, 0.6240485023131742, 0.0, 0.0, 0.0, 0.0406426973721225] ,
# [0.0, 0.04911539003545221, 0.0, 0.0, 0.23170964380526302, 0.04939763863424854, 0.0, 0.06274981602259916, 0.2877393797550929, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7766379329419377, 0.7061522884727808, 0.0960955758632178, 0.0, 0.3030211175474031, 0.6001536689980314, 1.0477606549141325, 0.8043554214386265, 0.0, 0.0, 0.6259654177637075, 0.0, 0.18922614685864272, 1.008623439095296, 0.0, 0.0, 0.0, 0.0, 0.12626765101490742, 0.7007917534875767, 0.0, 0.0406426973721225, 0.0] ,
# ],dtype = torch.float, device = device)
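
    # The classifier takes both a forward and a backward graph; the backward adjacency is
    # simply the transpose of the forward matrix, presumably so the GCN can propagate
    # information along both edge directions.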
    bw_adj = fw_adj.t()

    # build model architecture
    #edit removed num_classes from args
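    # node 0 of the graph appears to be a root node rather than a prediction target,
    # hence num_classes is one less than the number of graph nodes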
    num_classes = fw_adj.shape[0] - 1
    submodel = GCNClassifier(num_classes, fw_adj, bw_adj)
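
    # Load the pretrained GCN classifier weights: only checkpoint entries whose keys also
    # exist in the freshly built submodel are copied, so mismatched or extra keys in the
    # checkpoint are silently ignored.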
    # submodel.state_dict = torch.load(args.pretrained)
    state_dict = submodel.state_dict()
    state_dict.update({k: v for k, v in torch.load(args.pretrained).items() if k in state_dict})
    submodel.load_state_dict(state_dict)
    model = R2GenModel(args, tokenizer, submodel)
    # print(model)
    # print(model.state_dict())
    # raise Exception('lol')

    # #edit
    # if args.pretrained:
    #     pretrained_gcn = torch.load(args.pretrained)
    #     pretrained_state_dict = pretrained_gcn['model_state_dict']
    #     state_dict = model.state_dict()
    #     state_dict.update({k: v for k, v in pretrained_state_dict.items() if k in state_dict and 'fc' not in k})
    #     model.load_state_dict(state_dict)
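
    # The commented-out block above is an older loading path that read 'model_state_dict'
    # from the checkpoint and skipped 'fc' layers; the partial update on the submodel above
    # supersedes it.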

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build optimizer, learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build trainer and start to train
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler,
                      train_dataloader, val_dataloader, test_dataloader)
    trainer.train()


if __name__ == '__main__':
    main()