Skip to content

Commit a3d4715

Browse files
author
WenkelF
committed
Making predictor model-unspecific
1 parent da0d058 commit a3d4715

File tree

5 files changed

+21
-27
lines changed

5 files changed

+21
-27
lines changed

graphium/cli/hydra.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def run_training_finetuning(cfg: DictConfig) -> None:
7373
model_class=model_class,
7474
model_kwargs=model_kwargs,
7575
metrics=metrics,
76+
task_levels=datamodule.get_task_levels(),
7677
accelerator_type=accelerator_type,
7778
task_norms=datamodule.task_norms,
7879
)

graphium/config/_loader.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def load_predictor(
278278
model_class: Type[torch.nn.Module],
279279
model_kwargs: Dict[str, Any],
280280
metrics: Dict[str, MetricWrapper],
281+
task_levels: Dict[str, str],
281282
accelerator_type: str,
282283
task_norms: Optional[Dict[Callable, Any]] = None,
283284
) -> PredictorModule:
@@ -302,6 +303,7 @@ def load_predictor(
302303
model_class=model_class,
303304
model_kwargs=model_kwargs,
304305
metrics=metrics,
306+
task_levels=task_levels,
305307
task_norms=task_norms,
306308
**cfg_pred,
307309
)

graphium/data/datamodule.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,18 @@ def _get_task_key(self, task_level: str, task: str):
936936
task = task_prefix + task
937937
return task
938938

939+
def get_task_levels(self):
940+
task_level_map = {}
941+
942+
for task, task_args in self.task_specific_args.items():
943+
if isinstance(task_args, DatasetProcessingParams):
944+
task_args = task_args.__dict__ # Convert the class to a dictionary
945+
task_level_map.update({
946+
task: task_args["task_level"]
947+
})
948+
949+
return task_level_map
950+
939951
def prepare_data(self):
940952
"""Called only from a single process in distributed settings. Steps:
941953

graphium/trainer/predictor.py

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
model_class: Type[nn.Module],
2828
model_kwargs: Dict[str, Any],
2929
loss_fun: Dict[str, Union[str, Callable]],
30+
task_levels: Dict[str, str],
3031
random_seed: int = 42,
3132
optim_kwargs: Optional[Dict[str, Any]] = None,
3233
torch_scheduler_kwargs: Optional[Dict[str, Any]] = None,
@@ -69,6 +70,7 @@ def __init__(
6970

7071
self.target_nan_mask = target_nan_mask
7172
self.multitask_handling = multitask_handling
73+
self.task_levels = task_levels
7274
self.task_norms = task_norms
7375

7476
super().__init__()
@@ -95,22 +97,11 @@ def __init__(
9597
)
9698
eval_options[task].check_metrics_validity()
9799

98-
# Work-around to retain task level when model_kwargs are modified for FullGraphFinetuningNetwork
99-
if "task_heads_kwargs" in model_kwargs.keys():
100-
task_heads_kwargs = model_kwargs["task_heads_kwargs"]
101-
elif "pretrained_model_kwargs" in model_kwargs.keys():
102-
# This covers finetuning cases where we finetune from the task_heads
103-
task_heads_kwargs = model_kwargs["pretrained_model_kwargs"]["task_heads_kwargs"]
104-
else:
105-
raise ValueError("incorrect model_kwargs")
106-
self.task_heads_kwargs = task_heads_kwargs
107-
108100
self._eval_options_dict: Dict[str, EvalOptions] = eval_options
109101
self._eval_options_dict = {
110102
self._get_task_key(
111-
task_level=task_heads_kwargs[key]["task_level"],
103+
task_level=task_levels[key],
112104
task=key
113-
# task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
114105
): value
115106
for key, value in self._eval_options_dict.items()
116107
}
@@ -119,22 +110,10 @@ def __init__(
119110

120111
self.model = self._model_options.model_class(**self._model_options.model_kwargs)
121112

122-
# Maintain module map to easily select modules
123-
# We now need to define the module_map in pretrained_model in FinetuningNetwork
124-
# self._module_map = OrderedDict(
125-
# pe_encoders=self.model.encoder_manager,
126-
# pre_nn=self.model.pre_nn,
127-
# pre_nn_edges=self.model.pre_nn_edges,
128-
# gnn=self.model.gnn,
129-
# graph_output_nn=self.model.task_heads.graph_output_nn,
130-
# task_heads=self.model.task_heads.task_heads,
131-
# )
132-
133113
loss_fun = {
134114
self._get_task_key(
135-
task_level=task_heads_kwargs[key]["task_level"],
115+
task_level=task_levels[key],
136116
task=key
137-
# task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
138117
): value
139118
for key, value in loss_fun.items()
140119
}
@@ -338,7 +317,7 @@ def _general_step(self, batch: Dict[str, Tensor], step_name: str, to_cpu: bool)
338317
preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}
339318

340319
preds = {
341-
self._get_task_key(task_level=self.task_heads_kwargs[key]["task_level"], task=key): value
320+
self._get_task_key(task_level=self.task_levels[key], task=key): value
342321
for key, value in preds.items()
343322
}
344323
# preds = {k: preds[ii] for ii, k in enumerate(targets_dict.keys())}

tests/test_finetuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_finetuning_pipeline(self):
6767
metrics = load_metrics(cfg)
6868

6969
predictor = load_predictor(
70-
cfg, model_class, model_kwargs, metrics, accelerator_type, datamodule.task_norms
70+
cfg, model_class, model_kwargs, metrics, datamodule.get_task_levels(), accelerator_type, datamodule.task_norms
7171
)
7272

7373
self.assertEqual(

0 commit comments

Comments (0)