@@ -27,6 +27,7 @@ def __init__(
         model_class: Type[nn.Module],
         model_kwargs: Dict[str, Any],
         loss_fun: Dict[str, Union[str, Callable]],
+        task_levels: Dict[str, str],
         random_seed: int = 42,
         optim_kwargs: Optional[Dict[str, Any]] = None,
         torch_scheduler_kwargs: Optional[Dict[str, Any]] = None,
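The new `task_levels` argument maps each task name to the level its predictions live at (e.g. `"graph"` or `"node"`), so the predictor no longer has to infer it from `model_kwargs`. A hypothetical call under the new signature; the class name `PredictorModule`, the model class, and the task names below are illustrative assumptions, not values from this PR:

```python
# Sketch only: the names below (PredictorModule, FullGraphMultiTaskNetwork,
# tasks "homolumo"/"tox21") are illustrative assumptions.
predictor = PredictorModule(
    model_class=FullGraphMultiTaskNetwork,
    model_kwargs=model_kwargs,  # built elsewhere from the experiment config
    loss_fun={"homolumo": "mae", "tox21": "bce"},
    task_levels={"homolumo": "graph", "tox21": "graph"},  # new argument
    random_seed=42,
)
```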
@@ -69,6 +70,7 @@ def __init__(
 
         self.target_nan_mask = target_nan_mask
         self.multitask_handling = multitask_handling
+        self.task_levels = task_levels
         self.task_norms = task_norms
 
         super().__init__()
@@ -95,22 +97,11 @@ def __init__(
             )
             eval_options[task].check_metrics_validity()
 
-        # Work-around to retain task level when model_kwargs are modified for FullGraphFinetuningNetwork
-        if "task_heads_kwargs" in model_kwargs.keys():
-            task_heads_kwargs = model_kwargs["task_heads_kwargs"]
-        elif "pretrained_model_kwargs" in model_kwargs.keys():
-            # This covers finetuning cases where we finetune from the task_heads
-            task_heads_kwargs = model_kwargs["pretrained_model_kwargs"]["task_heads_kwargs"]
-        else:
-            raise ValueError("incorrect model_kwargs")
-        self.task_heads_kwargs = task_heads_kwargs
-
         self._eval_options_dict: Dict[str, EvalOptions] = eval_options
         self._eval_options_dict = {
             self._get_task_key(
-                task_level=task_heads_kwargs[key]["task_level"],
+                task_level=task_levels[key],
                 task=key
-                # task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
             ): value
             for key, value in self._eval_options_dict.items()
         }
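The same remapping is applied to `loss_fun` just below and to `preds` in `_general_step`: per-task keys are rewritten into level-qualified keys via `_get_task_key`, now driven by the explicit `task_levels` dict instead of the removed work-around that dug `task_heads_kwargs` out of `model_kwargs` (with a finetuning fallback into `pretrained_model_kwargs`). A minimal self-contained sketch of the pattern, assuming for illustration that `_get_task_key` prefixes the task name with its level; the real implementation is not shown in this diff:

```python
def get_task_key(task_level: str, task: str) -> str:
    # Assumed behavior for illustration: qualify the task name with its
    # level unless it is already qualified. The real _get_task_key may differ.
    return task if task.startswith(task_level) else f"{task_level}_{task}"

task_levels = {"homolumo": "graph", "tox21": "graph"}  # illustrative
eval_options = {"homolumo": object(), "tox21": object()}  # stand-ins for EvalOptions

# Same comprehension shape as in the diff above.
eval_options = {
    get_task_key(task_level=task_levels[key], task=key): value
    for key, value in eval_options.items()
}
assert set(eval_options) == {"graph_homolumo", "graph_tox21"}
```

Passing `task_levels` explicitly decouples the predictor from the layout of `model_kwargs`, which the old code had to special-case for `FullGraphFinetuningNetwork`.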
@@ -119,22 +110,10 @@ def __init__(
 
         self.model = self._model_options.model_class(**self._model_options.model_kwargs)
 
-        # Maintain module map to easily select modules
-        # We now need to define the module_map in pretrained_model in FinetuningNetwork
-        # self._module_map = OrderedDict(
-        #     pe_encoders=self.model.encoder_manager,
-        #     pre_nn=self.model.pre_nn,
-        #     pre_nn_edges=self.model.pre_nn_edges,
-        #     gnn=self.model.gnn,
-        #     graph_output_nn=self.model.task_heads.graph_output_nn,
-        #     task_heads=self.model.task_heads.task_heads,
-        # )
-
         loss_fun = {
             self._get_task_key(
-                task_level=task_heads_kwargs[key]["task_level"],
+                task_level=task_levels[key],
                 task=key
-                # task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
             ): value
             for key, value in loss_fun.items()
         }
@@ -338,7 +317,7 @@ def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool)
         preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
 
         preds = {
-            self._get_task_key(task_level=self.task_heads_kwargs[key]["task_level"], task=key): value
+            self._get_task_key(task_level=self.task_levels[key], task=key): value
             for key, value in preds.items()
         }
         # preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}