Finetuning pipeline #414
Conversation
Codecov Report
@@            Coverage Diff             @@
##             main     #414      +/-   ##
==========================================
- Coverage   66.87%   64.74%    -2.14%
==========================================
  Files          82       89        +7
  Lines        7838     8211      +373
==========================================
+ Hits         5242     5316       +74
- Misses       2596     2895      +299
Flags with carried forward coverage won't be shown.
Here's my review. It mostly aligns with what we discussed last time.
graphium/finetuning/utils.py
Outdated
from copy import deepcopy


def modify_cfg_for_finetuning(cfg):
This could be a function within FeedForwardNN and FullGraphNetwork.
I am not sure we can remove the function and move it inside the networks. As discussed in #411, it might be possible to have a similar function within load_architecture in graphium/data/_loader.py, but that depends on the PR.
Hi @WenkelF, sorry for the delay here!
Thank you for this first implementation!
I left comments whenever something stood out to me, but I'm aware that this PR is still WIP. Sorry if I pointed out some things that you were already planning to change.
It would be super helpful if you could document the main fine-tuning "flow". Furthermore, I would suggest simplifying the process by adding support for one feature at a time instead of having half-implemented features. This will make understanding, debugging, testing, and maintaining the code base a lot easier.
Happy to help next week!
…FullGraphFinetuningNetwork
…ut_nn and gnn; addressing comments
Don't forget that the objective is to release a working but incomplete version first, then refine it with more complex fine-tuning possibilities.
qm9:
  task_level: graph
  out_dim: 19
  hidden_dims: 128
  depth: 2
  activation: relu
  last_activation: none
  dropout: *dropout
  normalization: *normalization
  last_normalization: "none"
  residual_type: none
tox21:
  task_level: graph
  out_dim: 12
  hidden_dims: 64
  depth: 2
  activation: relu
  last_activation: sigmoid
  dropout: *dropout
  normalization: *normalization
  last_normalization: "none"
  residual_type: none
zinc:
  task_level: graph
  out_dim: 3
  hidden_dims: 32
  depth: 2
  activation: relu
  last_activation: none
This doesn't make sense. We should not have any architectural choices from the original pre-trained model in here, only the things that would change.
That way, we can take different pre-trained models that have different hparams/seeds and fine-tune them all with the same file.
I agree. The configurations are still structured in a way where we have access to both the full config of the pretrained model and the pretraining-related config, and the modify_cfg_for_finetuning function consolidates this information into one config.
This will be fixed once we incorporate the new hydra config from #421. For now, we still need modify_cfg_for_finetuning, so it could be good to wait for the final version.
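For illustration, here is a minimal sketch of what a config-consolidation step like modify_cfg_for_finetuning can look like conceptually. The function name merge_finetuning_overrides and the recursive-merge behaviour are assumptions for this example, not the actual implementation in graphium/finetuning/utils.py.

from copy import deepcopy


def merge_finetuning_overrides(pretrained_cfg: dict, finetuning_overrides: dict) -> dict:
    """Overlay finetuning-specific overrides onto a copy of the pretrained model's config."""
    cfg = deepcopy(pretrained_cfg)
    for key, value in finetuning_overrides.items():
        if isinstance(value, dict) and isinstance(cfg.get(key), dict):
            cfg[key] = merge_finetuning_overrides(cfg[key], value)  # recurse into nested sections
        else:
            cfg[key] = value  # only the overridden leaves change
    return cfg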
graphium/data/datamodule.py
Outdated
try:
    if "epoch_sampling_fraction" in args[task].keys():
        args[task].pop("epoch_sampling_fraction")
except:
    pass
I don't even understand the point of having a try and an if there. dict.pop works even if the key is not available (when called with a default). But we need to make sure that args[task] is not used elsewhere, even outside the current function, since dicts are passed by reference. We only want to remove epoch_sampling_fraction for the hash key. So I would suggest the following.
- try:
-     if "epoch_sampling_fraction" in args[task].keys():
-         args[task].pop("epoch_sampling_fraction")
- except:
-     pass
+ args[task] = deepcopy(args[task])
+ args[task].pop("epoch_sampling_fraction")
Indeed, this is not an ideal fix. I wanted to investigate the issue a bit more.
We cannot use args[task].pop("epoch_sampling_fraction") because args[task] may be of class DatasetProcessingParams instead of Dict. In particular, ADMETBenchmarkDataModule makes use of DatasetProcessingParams.
The error originates from the changes in 4b82ba3, where the line args[task].pop("epoch_sampling_fraction") was added. It did not cause errors back then because we were using a Dict in all configs (although I see a comment # To be replaced by a new class "DatasetParams" everywhere it appears).
Will create an issue and think about a fix.
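As a purely illustrative sketch (not the fix adopted in the PR), one way to sidestep the dict-vs-DatasetProcessingParams issue is to build a plain dict just for the cache hash. The helper name hashable_task_args is hypothetical.

from typing import Any, Dict


def hashable_task_args(task_args: Any) -> Dict[str, Any]:
    """Return a plain dict of the per-task args without 'epoch_sampling_fraction',
    leaving the original dict or DatasetProcessingParams-like object untouched."""
    as_dict = dict(task_args) if isinstance(task_args, dict) else dict(vars(task_args))
    as_dict.pop("epoch_sampling_fraction", None)  # no-op if the key is absent
    return as_dict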
graphium/finetuning/finetuning.py
Outdated
class GraphFinetuning(BaseFinetuning):
    def __init__(self, cfg, train_bn: bool = False):
Change to explicitly pass parameters.
- def __init__(self, cfg, train_bn: bool = False):
+ def __init__(self, fine-tuning, architecture, module_from_pretrained, ....................., train_bn: bool = False):
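For illustration, a sketch of what an explicit signature could look like. The parameter names below are hypothetical placeholders, not the signature that was eventually merged.

from pytorch_lightning.callbacks import BaseFinetuning


class GraphFinetuning(BaseFinetuning):
    def __init__(
        self,
        finetuning_module: str,               # hypothetical: e.g. "task_heads", the module to start finetuning from
        added_depth: int = 0,                 # hypothetical: extra layers appended to the finetuned module
        unfreeze_pretrained_depth: int = 0,   # hypothetical: pretrained layers to unfreeze along with the new ones
        epoch_unfreeze_all: int = 0,          # hypothetical: epoch at which the whole network is unfrozen
        train_bn: bool = False,               # whether BatchNorm layers stay trainable while frozen
    ):
        super().__init__()
        self.finetuning_module = finetuning_module
        self.added_depth = added_depth
        self.unfreeze_pretrained_depth = unfreeze_pretrained_depth
        self.epoch_unfreeze_all = epoch_unfreeze_all
        self.train_bn = train_bn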
Addressed in 76e2ba6
    modules = pl_module.model.task_heads.graph_output_nn
elif module == "task_heads":
    modules = pl_module.model.task_heads.task_heads
else: raise ValueError("Wrong module")
Addressed in 76e2ba6
graphium/finetuning/finetuning.py
Outdated
if module == "pe_encoders": | ||
modules = pl_module.model.encoder_manager | ||
elif module == "pre_nn": | ||
modules = pl_module.model.pre_nn | ||
elif module == "pre_nn_edges": | ||
modules = pl_module.model.pre_nn_edges | ||
elif module == "gnn": | ||
modules = pl_module.model.gnn | ||
elif module == "graph_output_nn": | ||
modules = pl_module.model.task_heads.graph_output_nn | ||
elif module == "task_heads": | ||
modules = pl_module.model.task_heads.task_heads |
I would define all these in a dictionary _module_map = {pe_encoders: pl_module.model.encoder_manager, ...} directly in the __init__. That way, with inheritance, someone could modify the entries without copy-pasting all the logic. _module_map can replace the module_list you already have.
But instead of a regular dict, using an OrderedDict would also allow you to say something like "freeze everything before task_heads" in a very simple way.
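A minimal sketch of that suggestion, assuming the attribute layout from the snippet above; the helper names build_module_map and freeze_everything_before are illustrative only.

from collections import OrderedDict

from pytorch_lightning.callbacks import BaseFinetuning


def build_module_map(pl_module):
    # Ordered from the input side of the network to the output side, mirroring the if/elif chain above
    return OrderedDict(
        pe_encoders=pl_module.model.encoder_manager,
        pre_nn=pl_module.model.pre_nn,
        pre_nn_edges=pl_module.model.pre_nn_edges,
        gnn=pl_module.model.gnn,
        graph_output_nn=pl_module.model.task_heads.graph_output_nn,
        task_heads=pl_module.model.task_heads.task_heads,
    )


def freeze_everything_before(module_map, finetuning_module: str, train_bn: bool = False):
    # Because the dict is ordered, "freeze everything before task_heads" is just:
    # freeze modules until the named one is reached.
    for name, module in module_map.items():
        if name == finetuning_module:
            break
        BaseFinetuning.freeze(module, train_bn=train_bn)

Calling freeze_everything_before(build_module_map(pl_module), "task_heads") would then freeze everything up to, but not including, the task heads.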
It's a bad idea to have a FullGraphFinetuningNetwork that basically copy-pastes most of the functionality of FullGraphNetwork.
Either use inheritance, or implement the fine-tuning logic directly within FullGraphNetwork.
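For illustration, a minimal sketch of the inheritance option; the import path, constructor, and forward signature of FullGraphNetwork are assumptions here, not the actual graphium API.

from torch import nn

from graphium.nn.architectures import FullGraphNetwork  # import path assumed for this sketch


class FullGraphFinetuningNetwork(FullGraphNetwork):
    """Reuse the pretrained architecture and add only what changes for finetuning."""

    def __init__(self, *args, finetuning_head: nn.Module = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetuning_head = finetuning_head  # optional extra head for the downstream task

    def forward(self, batch):
        out = super().forward(batch)  # everything else is inherited from FullGraphNetwork
        return out if self.finetuning_head is None else self.finetuning_head(out)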
@DomInvivo thanks for your comments. Here is also a quick overview of the updates.
Updates:
Remarks:
@DomInvivo this pull request introduces the following:
- The finetuning pipeline is maintained separately from the existing architectures under graphium/finetuning.
- Training is handled by the [...].
- All methods in graphium/finetuning are implemented such that they are not specific to a pretrained model or finetuning head. This is achieved by requiring the pretrained model to come with a module_map (see, e.g., [...]).
- The new unit test [...].
- The updates to hydra make it easy to switch between benchmarking and finetuning. Major changes are documented here: WenkelF#10
Mostly looks good, great work on this major PR!
A few changes to make, and some comments.
graphium/trainer/predictor.py
Outdated
if "task_heads_kwargs" in model_kwargs.keys(): | ||
task_heads_kwargs = model_kwargs["task_heads_kwargs"] | ||
elif "pretrained_model_kwargs" in model_kwargs.keys(): | ||
# This covers finetuning cases where we finetune from the task_heads | ||
task_heads_kwargs = model_kwargs["pretrained_model_kwargs"]["task_heads_kwargs"] | ||
else: | ||
raise ValueError("incorrect model_kwargs") |
I don't think this should be here. I think that, if you are using a pre-trained model, you should pass model_kwargs["pretrained_model_kwargs"] directly into model_kwargs.
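A small sketch of that idea, assuming the dict keys shown in the snippet above; the helper name resolve_model_kwargs is hypothetical.

def resolve_model_kwargs(model_kwargs: dict) -> dict:
    """When finetuning from a pretrained model, use its kwargs directly,
    so the predictor never needs to branch on where task_heads_kwargs lives."""
    return model_kwargs.get("pretrained_model_kwargs", model_kwargs)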
Removed in a3d4715 as explained below
graphium/trainer/predictor.py
Outdated
task_level=task_heads_kwargs[key]["task_level"],
task=key
# task_level=model_kwargs["task_heads_kwargs"][key]["task_level"], task=key
In general, the PredictorModule should be agnostic to the model passed. By having the self._get_task_key here, it forces a certain architecture in the config, which is not very flexible.
I see that this logic was already introduced in the code prior to this PR. If it requires too many changes, let's open a new issue.
You are right, thanks for pointing that out.
We could achieve this by getting the task-specific information (which is only the task level as far as I know) from the datamodule.
Here is how this could be done:
a3d4715
What do you think?
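For illustration, a minimal sketch of that idea; the function name and the shape of the datamodule's per-task arguments are assumptions, not the exact code in a3d4715.

from typing import Dict


def get_task_levels(task_specific_args: Dict[str, dict]) -> Dict[str, str]:
    """Collect each task's 'task_level' (e.g. 'graph' or 'node') from the datamodule's
    per-task arguments, so it can be passed to the PredictorModule explicitly instead
    of being read out of task_heads_kwargs inside the predictor."""
    return {task: args["task_level"] for task, args in task_specific_args.items()}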
graphium/trainer/predictor.py
Outdated
task_level=task_heads_kwargs[key]["task_level"],
task=key
Again, I don't like how the config structure is imposed. Perhaps task_level should simply be passed to the PredictorModule to keep flexibility.
Good idea, this is implemented in a3d4715
Why duplicate the model for CPU and GPU? Models should be agnostic to both the training hardware and the fine-tuning hardware. You can train on CPU and fine-tune on GPU or IPU.
Yes, I agree. I only changed from GPU to CPU because GitHub cannot run unit tests on GPU. Should I remove the GPU model?
Yes, you can remove the GPU model.
@zhiyil-graphcore @s-maddrellmander We'll need your help here to fix the tests for IPU. And ideally, to have a test that loads a CPU-trained model onto IPU for finetuning.
Additional documentation pass and addressing comments from the PR review
@cwognum the bug in the finetuning training is fixed here: febdf2d. I missed a deepcopy operation when defining the data hashes for the TDC datasets. We include the first 5 rows of the df when generating the hash, and the bug reduced the datasets to those 5 rows (molecules) as well. Make sure to remove the TDC datasets from the datacache. You can use [...]
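Purely as an illustration of the kind of bug described above, here is a sketch of hashing a copy of the first rows so the dataset's DataFrame itself is never truncated; the hashing details are assumptions, not graphium's actual caching code.

import hashlib

import pandas as pd


def dataset_hash(df: pd.DataFrame, n_rows: int = 5) -> str:
    # Slice a copy so building the hash can never truncate or mutate the dataset's DataFrame
    df_head = df.iloc[:n_rows].copy()
    row_hashes = pd.util.hash_pandas_object(df_head, index=True).values
    return hashlib.md5(row_hashes.tobytes()).hexdigest()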
@DomInvivo I added some final improvements in 24354ee.
(3.) makes it much easier to finetune from modules other than the task_heads without manually re-defining the downstream network. When finetuning from the task_heads, it is not needed.
Merging the IPU tests so the IPU CLI test is in the correct environment for the major updates to the workflow in #414.
@WenkelF - try merging from the main branch; I've made a change to the IPU test CLI that should take into account the changes made in this PR. If that doesn't work, let me know.
@DomInvivo as discussed, here is a first draft of the finetuning pipeline.
Two possible pipelines:
- expts/main_run_finetuning_v1.py (probably to be removed):
- expts/main_run_finetuning_v2.py:
Main TODOs: