Merge pull request #424 from datamol-io/speed_test
Speed test + largemix mpnn config
DomInvivo authored Aug 2, 2023
2 parents 57399b3 + 2d51bbb commit 6006da5
Showing 14 changed files with 226 additions and 165 deletions.
4 changes: 3 additions & 1 deletion expts/hydra-configs/accelerator/ipu.yaml
@@ -1,6 +1,8 @@
type: ipu
ipu_config:
- deviceIterations(30) # IPU would require large batches to be ready for the model.
- deviceIterations(60) # IPU would require large batches to be ready for the model.
# 60 for PCQM4mv2
# 30 for largemix
- replicationFactor(16)
# - enableProfiling("graph_analyser") # The folder where the profile will be stored
# - enableExecutableCaching("pop_compiler_cache")
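For context on why batches must be ready in advance: under PopTorch semantics, deviceIterations and replicationFactor multiply to give the number of micro-batches the host dataloader must deliver for every call into the model. A worked example using only the values shown in this hunk (illustrative arithmetic, not part of the config):
# micro-batches consumed per host step = deviceIterations * replicationFactor
#   PCQM4mv2 setting: 60 * 16 = 960
#   largemix setting: 30 * 16 = 480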
@@ -1,70 +1,5 @@
# @package _global_

datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
homolumo:
df: null
task_level: "graph"
df_path: graphium/data/PCQM4M/pcqm4mv2.csv
# wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv
# or set path as https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv directly
smiles_col: "cxsmiles"
label_cols: ["homo_lumo_gap"]
# sample_size: 8000 # use sample_size for test
splits_path: graphium/data/PCQM4M/split_dict_v2.pt # Download with `wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
split_names: ["train", "valid", "test-dev"]
# graphium/data/PCQM4Mv2/split_dict.pt
# graphium/data/PCQM4Mv2/pcqm4m_split.csv
# split_val: 0.1
# split_test: 0.1
seed: ${constants.seed}
label_normalization:
method: "normal"

# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/PCQM4Mv2/"
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # whether explicit hydrogens are included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # nomrlization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # nomrlization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16

# cache_data_path: .
num_workers: 30 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive across epochs.
# Setting persistent_workers to False may make the start of each epoch very long.

architecture:
model_type: FullGraphMultiTaskNetwork
mup_base_path: null
@@ -144,78 +79,46 @@ architecture:
last_normalization: "none"
residual_type: none

task_heads:
homolumo:
task_level: graph
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: *dropout
normalization: *normalization
last_normalization: "none"
residual_type: none

#Task-specific
predictor:
metrics_on_progress_bar:
homolumo: []
metrics_on_training_set:
homolumo: ["pearsonr"]
loss_fun:
homolumo: mae_ipu
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor homolumo/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: zero mask, ignore: ignore NaN values in the loss
flag_kwargs:
n_steps: 0 # 1
alpha: 0.0 # 0.01

# Task-specific
metrics:
homolumo:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label
datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
# Featurization
prepare_dict_or_graph: pyg:graph
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: ${constants.datacache_path}
num_workers: 40 # -1 to use all
persistent_workers: False # whether to keep dataloader workers alive across epochs.
# Setting persistent_workers to False may make the start of each epoch very long.
featurization:
# OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence),
# 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring',
# 'num_chiral_centers (not included yet)']
atom_property_list_onehot: [atomic-number, group, period, total-valence]
atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring]
# OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring']
edge_property_list: [bond-type-onehot, stereo, in-ring]
add_self_loop: False
explicit_H: False # whether explicit hydrogens are included
use_bonds_weights: False
pos_encoding_as_features: # encoder dropout 0.18
pos_types:
lap_eigvec:
pos_level: node
pos_type: laplacian_eigvec
num_pos: 8
normalization: "none" # nomrlization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
lap_eigval:
pos_level: node
pos_type: laplacian_eigval
num_pos: 8
normalization: "none" # nomrlization already applied on the eigen vectors
disconnected_comp: True # if eigen values/vector for disconnected graph are included
rw_pos: # use same name as pe_encoder
pos_level: node
pos_type: rw_return_probs
ksteps: 16

trainer:
seed: ${constants.seed}
logger:
save_dir: logs/PCQMv2
name: ${constants.name}
project: PCQMv2_mpnn
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/PCMQ4Mv2/
filename: ${constants.name}
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
2 changes: 1 addition & 1 deletion expts/hydra-configs/main.yaml
@@ -13,4 +13,4 @@ defaults:

# Specializations
- training/accelerator: ${training}_${accelerator}
- training/model: ${training}_${model}
- training/model: ${training}_${model}
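As a hedged illustration of how these specialization defaults compose (the resolved file names below are assumptions, not shown in this commit): selecting training=pcqm4m, accelerator=ipu and model=mpnn would make Hydra resolve the interpolations as follows.
# Hypothetical resolution of the specialization entries above:
#   with training: pcqm4m, accelerator: ipu, model: mpnn
# - training/accelerator: pcqm4m_ipu   # would load training/accelerator/pcqm4m_ipu.yaml, if present
# - training/model: pcqm4m_mpnn        # would load training/model/pcqm4m_mpnn.yaml, if present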
15 changes: 10 additions & 5 deletions expts/hydra-configs/model/gpspp.yaml
@@ -1,5 +1,15 @@
# @package _global_

datamodule:
args:
batch_size_training: 32
featurization:
conformer_property_list: [positions_3d]

trainer:
trainer:
accumulate_grad_batches: 2

architecture:
pe_encoders:
encoders:
@@ -31,8 +41,3 @@ architecture:
num_heads: 32
droppath_rate_attn: 0.0
droppath_rate_ffn: 0.0

datamodule:
args: # Matches that in the test_multitask_datamodule.py case.
featurization:
conformer_property_list: [positions_3d]
8 changes: 8 additions & 0 deletions expts/hydra-configs/model/mpnn.yaml
@@ -1,5 +1,13 @@
# @package _global_

datamodule:
args:
batch_size_training: 64

trainer:
trainer:
accumulate_grad_batches: 1

architecture:
gnn:
layer_type: 'pyg:gps'
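Comparing the two model configs above, using only values visible in this diff: gpspp.yaml pairs a micro-batch of 32 with accumulate_grad_batches: 2, while mpnn.yaml pairs 64 with 1, so both arrive at the same effective batch per optimizer step.
# Effective per-step batch (micro-batch * accumulation), from the values shown above:
#   gpspp: 32 * 2 = 64
#   mpnn:  64 * 1 = 64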
62 changes: 62 additions & 0 deletions expts/hydra-configs/tasks/pcqm4m.yaml
@@ -0,0 +1,62 @@
# @package _global_

architecture:
task_heads:
homolumo:
task_level: graph
out_dim: 1
hidden_dims: 256
depth: 2 # Not needed if we have hidden_dims
activation: relu
last_activation: none
dropout: 0.18
normalization: layer_norm
last_normalization: "none"
residual_type: none

#Task-specific
predictor:
metrics_on_progress_bar:
homolumo: []
metrics_on_training_set:
homolumo: ["pearsonr"]
loss_fun:
homolumo: mae_ipu

# Task-specific
metrics:
homolumo:
- name: mae
metric: mae_ipu
target_nan_mask: null
multitask_handling: mean-per-label
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
threshold_kwargs: null
target_nan_mask: null
multitask_handling: mean-per-label

datamodule:
module_type: "MultitaskFromSmilesDataModule"
# module_type: "FakeDataModule" # Option to use generated data
args: # Matches that in the test_multitask_datamodule.py case.
task_specific_args: # To be replaced by a new class "DatasetParams"
homolumo:
df: null
task_level: "graph"
df_path: graphium/data/PCQM4M/pcqm4mv2.csv
# wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv
# or set path as https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/pcqm4mv2.csv directly
smiles_col: "cxsmiles"
label_cols: ["homo_lumo_gap"]
# sample_size: 8000 # use sample_size for test
splits_path: graphium/data/PCQM4M/split_dict_v2.pt # Download with `wget https://storage.googleapis.com/datasets-public-research/PCQM4M/cxsmiles/split_dict_v2.pt`
split_names: ["train", "valid", "test-dev"]
# graphium/data/PCQM4Mv2/split_dict.pt
# graphium/data/PCQM4Mv2/pcqm4m_split.csv
# split_val: 0.1
# split_test: 0.1
seed: ${constants.seed}
label_normalization:
method: "normal"
@@ -11,15 +11,13 @@ datamodule:
max_num_nodes_per_graph: 30 # valid max nodes: 51, max_edges: 118
max_num_edges_per_graph: 120
# Data handling-related
batch_size_training: 32
batch_size_inference: 16

predictor:
metrics_every_n_train_steps: 1000
metrics_every_n_train_steps: 100
optim_kwargs:
loss_scaling: 1024

trainer:
trainer:
precision: 16-true
accumulate_grad_batches: 2
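For context on loss_scaling under 16-true precision (a general mixed-precision note, not graphium-specific): the loss is multiplied by the scaling factor before the backward pass and the gradients are divided by it afterwards, which keeps small gradient values from underflowing in float16.
# General idea behind loss_scaling: 1024 (illustrative comments only, not library code):
#   scaled_loss = loss * 1024          # backward pass runs on the scaled loss
#   grads       = scaled_grads / 1024  # gradients are un-scaled before the optimizer step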
@@ -7,6 +7,7 @@ constants:
seed: 42
max_epochs: 100
raise_train_error: true # Whether the code should raise an error if it crashes during training
datacache_path: "/localdata/PCQM4Mv2/"

trainer:
model_checkpoint:
@@ -7,6 +7,7 @@ constants:
seed: 42
max_epochs: 100
raise_train_error: true # Whether the code should raise an error if it crashes during training
datacache_path: "/localdata/PCQM4Mv2/"

trainer:
model_checkpoint:
44 changes: 44 additions & 0 deletions expts/hydra-configs/training/pcqm4m.yaml
@@ -0,0 +1,44 @@
# @package _global_

predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-4 # warmup can be scheduled using torch_scheduler_kwargs
# weight_decay: 1.e-7
torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: &max_epochs 100
warmup_epochs: 10
verbose: False
scheduler_kwargs:
# monitor: &monitor homolumo/mae/train
# mode: min
# frequency: 1
target_nan_mask: null # null: no mask, 0: zero mask, ignore: ignore NaN values in the loss
flag_kwargs:
n_steps: 0 # 1
alpha: 0.0 # 0.01


trainer:
seed: ${constants.seed}
logger:
save_dir: logs/PCQMv2
name: ${constants.name}
project: PCQMv2_mpnn
#early_stopping:
# monitor: *monitor
# min_delta: 0
# patience: 10
# mode: &mode min
model_checkpoint:
dirpath: models_checkpoints/PCMQ4Mv2/
filename: ${constants.name}
#monitor: *monitor
#mode: *mode
save_top_k: 1
every_n_epochs: 100
trainer:
max_epochs: *max_epochs
min_epochs: 1
check_val_every_n_epoch: 20
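Assuming WarmUpLinearLR follows the usual warmup-then-linear-decay shape (an assumption about the scheduler, not documented in this diff), the learning rate over the 100 epochs would look roughly like:
# Assumed shape of WarmUpLinearLR with lr: 4.e-4, warmup_epochs: 10, max_num_epochs: 100 (illustrative only):
#   epoch <= 10:  lr = 4.e-4 * epoch / 10                    # linear warmup
#   epoch  > 10:  lr = 4.e-4 * (100 - epoch) / (100 - 10)    # linear decay towards 0 at epoch 100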