
Merge pull request #433 from WenkelF/finetuning
Adding unit test for finetuning from gnn
DomInvivo authored Aug 16, 2023
2 parents 71182f4 + 46c166b commit 6c0787c
Showing 3 changed files with 363 additions and 45 deletions.
141 changes: 141 additions & 0 deletions graphium/config/dummy_finetuning_from_gnn.yaml
@@ -0,0 +1,141 @@
# Here, we are finetuning a FullGraphMultitaskNetwork
# trained on ToyMix. We finetune from the gnn on the
# TDC dataset lipophilicity_astrazeneca.

# Here are the changes to the architecture:
# Change gnn:
#   depth: 4 -> 4 - 2 + 3 = 5
#
# Keep modules after gnn and apply modifications
#   graph_output_nn/graph:
#     pooling: sum -> mean
#     depth: 1 -> 2
#   task_heads/zinc:
#     new_sub_module: lipophilicity_astrazeneca
#     out_dim: 3 -> 1
#
# Finetuning training:
#   unfreeze one additional layer of the pretrained gnn;
#   after 1 epoch, unfreeze all layers


###################################################
########### How to combine information ###########
###################################################


###########################
### FINETUNING-SPECIFIC ###
###########################

finetuning:
  # New task
  task: lipophilicity_astrazeneca
  level: graph

  # Pretrained model
  pretrained_model_name: dummy-pretrained-model
  finetuning_module: gnn

  # Changes to finetuning_module
  drop_depth: 2
  added_depth: 3

  keep_modules_after_finetuning_module: # optional
    graph_output_nn/graph:
      pooling: [mean]
      depth: 2
    task_heads/zinc:
      new_sub_module: lipophilicity_astrazeneca
      out_dim: 1

  # Finetuning training
  unfreeze_pretrained_depth: 1
  epoch_unfreeze_all: 1

constants:
  seed: 42
  max_epochs: 2

accelerator:
  float32_matmul_precision: medium
  type: cpu

predictor:
  random_seed: ${constants.seed}
  optim_kwargs:
    lr: 4.e-5
  scheduler_kwargs: null
  target_nan_mask: null
  multitask_handling: flatten # flatten, mean-per-label

  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
    max_num_epochs: 2
    warmup_epochs: 1
    verbose: False

  metrics_on_progress_bar:
    lipophilicity_astrazeneca: ["mae"]
  loss_fun:
    lipophilicity_astrazeneca: mae

metrics:
  lipophilicity_astrazeneca:
    - name: mae
      metric: mae
      target_nan_mask: null
      multitask_handling: flatten
      threshold_kwargs: null
    - name: spearman
      metric: spearmanr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: pearson
      metric: pearsonr
      threshold_kwargs: null
      target_nan_mask: null
      multitask_handling: mean-per-label
    - name: r2_score
      metric: r2
      target_nan_mask: null
      multitask_handling: mean-per-label
      threshold_kwargs: null

trainer:
  seed: ${constants.seed}
  trainer:
    precision: 32
    max_epochs: 2
    min_epochs: 1
    check_val_every_n_epoch: 1
    accumulate_grad_batches: 1

##################
### DATAMODULE ###
##################

datamodule:

  ### FROM FINETUNING ###

  module_type: "ADMETBenchmarkDataModule"
  args:
    # TDC specific
    tdc_benchmark_names: [lipophilicity_astrazeneca]
    tdc_train_val_seed: ${constants.seed}

    batch_size_training: 200
    batch_size_inference: 200
    featurization_n_jobs: 0
    num_workers: 0

    prepare_dict_or_graph: pyg:graph
    featurization_progress: True
    featurization_backend: "loky"
    persistent_workers: False
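
To make the surgery described in the header comments concrete: drop_depth removes the last layers of the pretrained gnn and added_depth appends fresh ones, giving 4 - 2 + 3 = 5. Below is a minimal plain-PyTorch sketch of that idea; the function name and the Linear stand-in layers are illustrative assumptions, not graphium's actual implementation.

import torch.nn as nn

def rebuild_finetuning_module(pretrained_layers: nn.ModuleList,
                              drop_depth: int,
                              added_depth: int,
                              hidden_dim: int) -> nn.ModuleList:
    # Keep all but the last `drop_depth` pretrained layers...
    kept = list(pretrained_layers)[: len(pretrained_layers) - drop_depth]
    # ...then append `added_depth` freshly initialised layers.
    added = [nn.Linear(hidden_dim, hidden_dim) for _ in range(added_depth)]
    return nn.ModuleList(kept + added)

# Depth-4 pretrained gnn, drop_depth=2, added_depth=3  ->  4 - 2 + 3 = 5
pretrained_gnn = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))
assert len(rebuild_finetuning_module(pretrained_gnn, 2, 3, 64)) == 5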




@@ -3,16 +3,18 @@
# (graph-level) on the TDC dataset lipophilicity_astrazeneca

# Here are the changes to the architecture:
# Change zinc task-head:
#   depth: 2 -> 2 - 1 + 2 = 3
#   out_dim: 3 -> 8
#
# Add finetuning head:
#   model_type: FeedForwardNN
#   out_dim: 1
#   hidden_dims: 8
#   depth: 2
#
# Finetuning training:
#   after 1 epoch, unfreeze all layers
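
For a shape check on the head described above: with the zinc task head's out_dim changed to 8, a depth-2 head with hidden_dims 8 and out_dim 1 could look like the following plain-PyTorch sketch (the ReLU activation is an assumption; this is not graphium's FeedForwardNN).

import torch
import torch.nn as nn

# Hypothetical stand-in for the finetuning head above: depth 2,
# hidden_dims 8, out_dim 1, fed by the zinc task head's new
# 8-dimensional output. The ReLU is an assumption.
finetuning_head = nn.Sequential(
    nn.Linear(8, 8),  # layer 1: task-head out_dim (8) -> hidden_dims (8)
    nn.ReLU(),
    nn.Linear(8, 1),  # layer 2: hidden_dims (8) -> out_dim (1)
)

x = torch.randn(32, 8)                     # a batch of task-head outputs
assert finetuning_head(x).shape == (32, 1)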


###################################################
@@ -55,11 +57,11 @@ finetuning:

  # Finetuning training
  unfreeze_pretrained_depth: 0
- epoch_unfreeze_all: 2
+ epoch_unfreeze_all: 1

constants:
  seed: 42
- max_epochs: 3
+ max_epochs: 2

accelerator:
  float32_matmul_precision: medium
@@ -75,7 +77,7 @@ predictor:

  torch_scheduler_kwargs:
    module_type: WarmUpLinearLR
-   max_num_epochs: 3
+   max_num_epochs: 2
    warmup_epochs: 1
    verbose: False

@@ -111,7 +113,7 @@ trainer:
  seed: ${constants.seed}
  trainer:
    precision: 32
-   max_epochs: 3
+   max_epochs: 2
    min_epochs: 1
    check_val_every_n_epoch: 1
    accumulate_grad_batches: 1
@@ -138,7 +140,6 @@ datamodule:
    prepare_dict_or_graph: pyg:graph
    featurization_progress: True
    featurization_backend: "loky"
-   processed_graph_data_path: "../datacache/neurips2023-small/"
    persistent_workers: False
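
Both configs share the staged-unfreezing behaviour controlled by unfreeze_pretrained_depth and epoch_unfreeze_all: the pretrained stack starts frozen, optionally with its last few layers trainable, and everything is unfrozen once the given epoch is reached. A hedged plain-PyTorch sketch of that logic, with illustrative names only:

import torch.nn as nn

def apply_unfreeze_schedule(pretrained_layers: nn.ModuleList,
                            epoch: int,
                            unfreeze_pretrained_depth: int,
                            epoch_unfreeze_all: int) -> None:
    # Newly added layers are assumed to stay trainable and are not
    # touched here; this only schedules the pretrained stack.
    if epoch >= epoch_unfreeze_all:
        pretrained_layers.requires_grad_(True)   # unfreeze all layers
        return
    pretrained_layers.requires_grad_(False)      # freeze the stack...
    n = len(pretrained_layers)
    for layer in list(pretrained_layers)[n - unfreeze_pretrained_depth:]:
        layer.requires_grad_(True)               # ...except the last few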

