From 225763855827ee3ecb81d6db77ecb7ca86ecd2d1 Mon Sep 17 00:00:00 2001 From: Zhiyi Li <86362692+zhiyil-graphcore@users.noreply.github.com> Date: Thu, 29 Jun 2023 10:28:30 +0100 Subject: [PATCH] some fixes: keys() and use normalization flag on val and test for baseline configs (#367) * minor fixes to test the toymix datasets * add normalize_val_test flag to baseline configs * attempting to fix tests and make changes to gey keys so that it works on both object and dict * fix a few other tests when run locally * changes on the keys and keys() * baseline config changes * minor changes after testing --- expts/main_run_multitask.py | 2 +- .../baseline/config_small_gcn_baseline.yaml | 13 + .../baseline/config_small_gin_baseline.yaml | 323 +---------------- .../baseline/config_small_gine_baseline.yaml | 325 +----------------- .../debug/config_large_gcn_debug.yaml | 8 +- graphium/data/collate.py | 11 +- graphium/data/datamodule.py | 4 +- graphium/data/utils.py | 7 + graphium/features/featurizer.py | 8 + graphium/ipu/ipu_dataloader.py | 5 +- graphium/ipu/ipu_wrapper.py | 5 +- graphium/nn/architectures/encoder_manager.py | 3 +- .../nn/architectures/global_architectures.py | 4 +- graphium/nn/pyg_layers/gps_pyg.py | 4 +- tests/test_dataset.py | 15 +- 15 files changed, 67 insertions(+), 670 deletions(-) diff --git a/expts/main_run_multitask.py b/expts/main_run_multitask.py index 569c9b4be..42033f851 100644 --- a/expts/main_run_multitask.py +++ b/expts/main_run_multitask.py @@ -36,7 +36,7 @@ # CONFIG_FILE = "expts/configs/config_mpnn_10M_pcqm4m.yaml" # CONFIG_FILE = "expts/neurips2023_configs/config_debug.yaml" # CONFIG_FILE = "expts/neurips2023_configs/config_large_mpnn.yaml" -CONFIG_FILE = "expts/neurips2023_configs/debug/config_large_gcn_debug.yaml" +CONFIG_FILE = "expts/neurips2023_configs/debug/config_small_gcn_debug.yaml" # CONFIG_FILE = "expts/neurips2023_configs/config_large_gin.yaml" # CONFIG_FILE = "expts/neurips2023_configs/config_large_gcn.yaml" # CONFIG_FILE = "expts/neurips2023_configs/config_large_gine.yaml" diff --git a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml index 2f0f974c3..975a4e992 100644 --- a/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml @@ -37,6 +37,17 @@ accelerator: - _Popart.set("defaultBufferingDepth", 128) - Precision.enableStochasticRounding(True) +# accelerator: +# type: cpu # cpu or ipu or gpu +# config_override: +# datamodule: +# batch_size_training: 64 +# batch_size_inference: 256 +# trainer: +# trainer: +# precision: 32 +# accumulate_grad_batches: 1 + datamodule: module_type: "MultitaskFromSmilesDataModule" # module_type: "FakeDataModule" # Option to use generated data @@ -54,6 +65,7 @@ datamodule: seed: *seed task_level: graph label_normalization: + normalize_val_test: True method: "normal" tox21: @@ -80,6 +92,7 @@ datamodule: seed: *seed task_level: graph label_normalization: + normalize_val_test: True method: "normal" # Featurization diff --git a/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml index 87e020ee2..bc96d1057 100644 --- a/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gin_baseline.yaml @@ -1,342 +1,21 @@ # Testing the gin model with the PCQMv2 dataset on IPU. 
constants: name: &name neurips2023_small_data_gin + config_override: "expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml" seed: &seed 1000 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -accelerator: - type: ipu # cpu or ipu or gpu - config_override: - datamodule: - args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 80 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 80 - # Data handling-related - batch_size_training: 5 - batch_size_inference: 5 - predictor: - optim_kwargs: - loss_scaling: 1024 - trainer: - trainer: - precision: 16 - accumulate_grad_batches: 4 - - ipu_config: - - deviceIterations(5) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - -datamodule: - module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data - args: # Matches that in the test_multitask_datamodule.py case. - task_specific_args: # To be replaced by a new class "DatasetParams" - qm9: - df: null - df_path: data/neurips2023/small-dataset/qm9.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"] - # sample_size: 2000 # use sample_size for test - splits_path: data/neurips2023/small-dataset/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt` - seed: *seed - task_level: graph - label_normalization: - method: "normal" - - tox21: - df: null - df_path: data/neurips2023/small-dataset/Tox21-7k-12-labels.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] - # sample_size: 2000 # use sample_size for test - splits_path: data/neurips2023/small-dataset/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt` - seed: *seed - task_level: graph - - zinc: - df: null - df_path: data/neurips2023/small-dataset/ZINC12k.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["SA", "logp", "score"] - # sample_size: 2000 # use sample_size for test - splits_path: data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget 
https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt` - seed: *seed - task_level: graph - label_normalization: - method: "normal" - - # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - processed_graph_data_path: "../datacache/neurips2023-small/" - featurization: - # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), - # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', - # 'num_chiral_centers (not included yet)'] - atom_property_list_onehot: [atomic-number, group, period, total-valence] - atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring] - # OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring'] - edge_property_list: [bond-type-onehot, stereo, in-ring] - add_self_loop: False - explicit_H: False # if H is included - use_bonds_weights: False - pos_encoding_as_features: # encoder dropout 0.18 - pos_types: - lap_eigvec: - pos_level: node - pos_type: laplacian_eigvec - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - lap_eigval: - pos_level: node - pos_type: laplacian_eigval - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - rw_pos: # use same name as pe_encoder - pos_level: node - pos_type: rw_return_probs - ksteps: 16 - - # cache_data_path: . - num_workers: 30 # -1 to use all - persistent_workers: False # if use persistent worker at the start of each epoch. - # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" - architecture: - model_type: FullGraphMultiTaskNetwork - mup_base_path: null - pre_nn: # Set as null to avoid a pre-nn network - out_dim: 64 - hidden_dims: 256 - depth: 2 - activation: relu - last_activation: none - dropout: &dropout 0.1 - normalization: &normalization layer_norm - last_normalization: *normalization - residual_type: none - - pre_nn_edges: null # Set as null to avoid a pre-nn network - - pe_encoders: - out_dim: 32 - pool: "sum" #"mean" "max" - last_norm: None #"batch_norm", "layer_norm" - encoders: #la_pos | rw_pos - la_pos: # Set as null to avoid a pre-nn network - encoder_type: "laplacian_pe" - input_keys: ["laplacian_eigvec", "laplacian_eigval"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - model_type: 'DeepSet' #'Transformer' or 'DeepSet' - num_layers: 2 - num_layers_post: 1 # Num. 
layers to apply after pooling - dropout: 0.1 - first_normalization: "none" #"batch_norm" or "layer_norm" - rw_pos: - encoder_type: "mlp" - input_keys: ["rw_return_probs"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - num_layers: 2 - dropout: 0.1 - normalization: "layer_norm" #"batch_norm" or "layer_norm" - first_normalization: "layer_norm" #"batch_norm" or "layer_norm" - - - gnn: # Set as null to avoid a post-nn network in_dim: 64 # or otherwise the correct value out_dim: &gnn_dim 96 hidden_dims: *gnn_dim - depth: 4 - activation: gelu - last_activation: none - dropout: 0.1 - normalization: "layer_norm" - last_normalization: *normalization - residual_type: simple - virtual_node: 'none' layer_type: 'pyg:gin' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps - layer_kwargs: null # Parameters for the model itself. You could define dropout_attn: 0.1 - - - graph_output_nn: - graph: - pooling: [sum] - out_dim: *gnn_dim - hidden_dims: *gnn_dim - depth: 1 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - - task_heads: - qm9: - task_level: graph - out_dim: 19 - hidden_dims: 256 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - tox21: - task_level: graph - out_dim: 12 - hidden_dims: 64 - depth: 2 - activation: relu - last_activation: sigmoid - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - zinc: - task_level: graph - out_dim: 3 - hidden_dims: 32 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - -#Task-specific -predictor: - metrics_on_progress_bar: - qm9: ["mae"] - tox21: ["auroc"] - zinc: ["mae"] - loss_fun: - qm9: mae_ipu - tox21: bce_ipu - zinc: mae_ipu - random_seed: *seed - optim_kwargs: - lr: 1.e-3 # warmup can be scheduled using torch_scheduler_kwargs - # weight_decay: 1.e-7 - torch_scheduler_kwargs: - module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 300 - warmup_epochs: 10 - verbose: False - scheduler_kwargs: - # monitor: &monitor qm9/mae/train - # mode: min - # frequency: 1 - target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label - multitask_handling: flatten # flatten, mean-per-label - -# Task-specific -metrics: - qm9: &qm9_metrics - - name: mae - metric: mae_ipu - target_nan_mask: null - multitask_handling: flatten - threshold_kwargs: null - - name: pearsonr - metric: pearsonr_ipu - threshold_kwargs: null - target_nan_mask: null - multitask_handling: mean-per-label - - name: r2_score - metric: r2_score_ipu - target_nan_mask: null - multitask_handling: mean-per-label - threshold_kwargs: null - tox21: - - name: auroc - metric: auroc_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: avpr - metric: average_precision_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: f1 > 0.5 - metric: f1 - multitask_handling: flatten - target_to_int: True - num_classes: 2 - average: micro - threshold_kwargs: &threshold_05 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: True - - name: precision > 0.5 - metric: precision - multitask_handling: flatten - average: micro - threshold_kwargs: *threshold_05 - zinc: *qm9_metrics trainer: seed: *seed logger: - save_dir: logs/neurips2023-small/ 
name: *name project: *name - #early_stopping: - # monitor: *monitor - # min_delta: 0 - # patience: 10 - # mode: &mode min model_checkpoint: dirpath: models_checkpoints/neurips2023-small-gin/ filename: *name - #monitor: *monitor - #mode: *mode - save_top_k: 1 - every_n_epochs: 100 - trainer: - max_epochs: *max_epochs - min_epochs: 1 - check_val_every_n_epoch: 20 diff --git a/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml b/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml index 05a87aa7d..431235bb4 100644 --- a/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml +++ b/expts/neurips2023_configs/baseline/config_small_gine_baseline.yaml @@ -1,350 +1,31 @@ # Testing the gine model with the PCQMv2 dataset on IPU. constants: name: &name neurips2023_small_data_gine + config_override: "expts/neurips2023_configs/baseline/config_small_gcn_baseline.yaml" seed: &seed 1000 - raise_train_error: true # Whether the code should raise an error if it crashes during training - -accelerator: - type: ipu # cpu or ipu or gpu - config_override: - datamodule: - args: - ipu_dataloader_training_opts: - mode: async - max_num_nodes_per_graph: 44 # train max nodes: 20, max_edges: 54 - max_num_edges_per_graph: 80 - ipu_dataloader_inference_opts: - mode: async - max_num_nodes_per_graph: 44 # valid max nodes: 51, max_edges: 118 - max_num_edges_per_graph: 80 - # Data handling-related - batch_size_training: 5 - batch_size_inference: 5 - predictor: - optim_kwargs: - loss_scaling: 1024 - trainer: - trainer: - precision: 16 - accumulate_grad_batches: 4 - - ipu_config: - - deviceIterations(5) # IPU would require large batches to be ready for the model. - - replicationFactor(16) - # - enableProfiling("graph_analyser") # The folder where the profile will be stored - # - enableExecutableCaching("pop_compiler_cache") - - TensorLocations.numIOTiles(128) - - _Popart.set("defaultBufferingDepth", 128) - - Precision.enableStochasticRounding(True) - -# accelerator: -# type: cpu # cpu or ipu or gpu -# config_override: -# datamodule: -# batch_size_training: 64 -# batch_size_inference: 256 -# trainer: -# trainer: -# precision: 32 -# accumulate_grad_batches: 1 - -datamodule: - module_type: "MultitaskFromSmilesDataModule" - # module_type: "FakeDataModule" # Option to use generated data - args: # Matches that in the test_multitask_datamodule.py case. 
- task_specific_args: # To be replaced by a new class "DatasetParams" - qm9: - df: null - df_path: data/neurips2023/small-dataset/qm9.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["A", "B", "C", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "u0", "u298", "h298", "g298", "cv", "u0_atom", "u298_atom", "h298_atom", "g298_atom"] - # sample_size: 2000 # use sample_size for test - splits_path: data/neurips2023/small-dataset/qm9_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/qm9_random_splits.pt` - seed: *seed - task_level: graph - label_normalization: - method: "normal" - - tox21: - df: null - df_path: data/neurips2023/small-dataset/Tox21-7k-12-labels.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21-7k-12-labels.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase", "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"] - # sample_size: 2000 # use sample_size for test - splits_path: data/neurips2023/small-dataset/Tox21_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/Tox21_random_splits.pt` - seed: *seed - task_level: graph - - zinc: - df: null - df_path: data/neurips2023/small-dataset/ZINC12k.csv.gz - # wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k.csv.gz - # or set path as the URL directly - smiles_col: "smiles" - label_cols: ["SA", "logp", "score"] - # sample_size: 2000 # use sample_size for test - splits_path: data/neurips2023/small-dataset/ZINC12k_random_splits.pt # Download with `wget https://storage.googleapis.com/graphium-public/datasets/neurips_2023/Small-dataset/ZINC12k_random_splits.pt` - seed: *seed - task_level: graph - label_normalization: - method: "normal" - - # Featurization - prepare_dict_or_graph: pyg:graph - featurization_n_jobs: 30 - featurization_progress: True - featurization_backend: "loky" - processed_graph_data_path: "../datacache/neurips2023-small/" - featurization: - # OGB: ['atomic_num', 'degree', 'possible_formal_charge', 'possible_numH' (total-valence), - # 'possible_number_radical_e', 'possible_is_aromatic', 'possible_is_in_ring', - # 'num_chiral_centers (not included yet)'] - atom_property_list_onehot: [atomic-number, group, period, total-valence] - atom_property_list_float: [degree, formal-charge, radical-electron, aromatic, in-ring] - # OGB: ['possible_bond_type', 'possible_bond_stereo', 'possible_is_in_ring'] - edge_property_list: [bond-type-onehot, stereo, in-ring] - add_self_loop: False - explicit_H: False # if H is included - use_bonds_weights: False - pos_encoding_as_features: # encoder dropout 0.18 - pos_types: - lap_eigvec: - pos_level: node - pos_type: laplacian_eigvec - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - lap_eigval: - pos_level: node - pos_type: laplacian_eigval - num_pos: 8 - normalization: "none" # nomrlization already applied on the eigen vectors - disconnected_comp: True # if eigen values/vector for disconnected graph are included - rw_pos: # use same name as pe_encoder - pos_level: node - pos_type: 
rw_return_probs - ksteps: 16 - - # cache_data_path: . - num_workers: 30 # -1 to use all - persistent_workers: False # if use persistent worker at the start of each epoch. - # Using persistent_workers false might make the start of each epoch very long. - featurization_backend: "loky" - architecture: - model_type: FullGraphMultiTaskNetwork - mup_base_path: null - pre_nn: # Set as null to avoid a pre-nn network - out_dim: 64 - hidden_dims: 256 - depth: 2 - activation: relu - last_activation: none - dropout: &dropout 0.1 - normalization: &normalization layer_norm - last_normalization: *normalization - residual_type: none - pre_nn_edges: # Set as null to avoid a pre-nn network out_dim: 32 hidden_dims: 128 depth: 2 activation: relu last_activation: none - dropout: *dropout - normalization: *normalization + dropout: 0.1 + normalization: &normalization layer_norm last_normalization: *normalization residual_type: none - pe_encoders: - out_dim: 32 - pool: "sum" #"mean" "max" - last_norm: None #"batch_norm", "layer_norm" - encoders: #la_pos | rw_pos - la_pos: # Set as null to avoid a pre-nn network - encoder_type: "laplacian_pe" - input_keys: ["laplacian_eigvec", "laplacian_eigval"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - model_type: 'DeepSet' #'Transformer' or 'DeepSet' - num_layers: 2 - num_layers_post: 1 # Num. layers to apply after pooling - dropout: 0.1 - first_normalization: "none" #"batch_norm" or "layer_norm" - rw_pos: - encoder_type: "mlp" - input_keys: ["rw_return_probs"] - output_keys: ["feat"] - hidden_dim: 64 - out_dim: 32 - num_layers: 2 - dropout: 0.1 - normalization: "layer_norm" #"batch_norm" or "layer_norm" - first_normalization: "layer_norm" #"batch_norm" or "layer_norm" - - - gnn: # Set as null to avoid a post-nn network out_dim: &gnn_dim 96 hidden_dims: *gnn_dim - depth: 4 - activation: gelu - last_activation: none - dropout: 0.1 - normalization: "layer_norm" - last_normalization: *normalization - residual_type: simple - virtual_node: 'none' layer_type: 'pyg:gine' #pyg:gine #'pyg:gps' # pyg:gated-gcn, pyg:gine,pyg:gps - layer_kwargs: null # Parameters for the model itself. 
You could define dropout_attn: 0.1 - - - graph_output_nn: - graph: - pooling: [sum] - out_dim: *gnn_dim - hidden_dims: *gnn_dim - depth: 1 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - - task_heads: - qm9: - task_level: graph - out_dim: 19 - hidden_dims: 256 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - tox21: - task_level: graph - out_dim: 12 - hidden_dims: 64 - depth: 2 - activation: relu - last_activation: sigmoid - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - zinc: - task_level: graph - out_dim: 3 - hidden_dims: 32 - depth: 2 - activation: relu - last_activation: none - dropout: *dropout - normalization: *normalization - last_normalization: "none" - residual_type: none - -#Task-specific -predictor: - metrics_on_progress_bar: - qm9: ["mae"] - tox21: ["auroc"] - zinc: ["mae"] - loss_fun: - qm9: mae_ipu - tox21: bce_ipu - zinc: mae_ipu - random_seed: *seed - optim_kwargs: - lr: 1.e-3 # warmup can be scheduled using torch_scheduler_kwargs - # weight_decay: 1.e-7 - torch_scheduler_kwargs: - module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 300 - warmup_epochs: 10 - verbose: False - scheduler_kwargs: - # monitor: &monitor qm9/mae/train - # mode: min - # frequency: 1 - target_nan_mask: null # null: no mask, 0: 0 mask, ignore-flatten, ignore-mean-per-label - multitask_handling: flatten # flatten, mean-per-label - -# Task-specific -metrics: - qm9: &qm9_metrics - - name: mae - metric: mae_ipu - target_nan_mask: null - multitask_handling: flatten - threshold_kwargs: null - - name: pearsonr - metric: pearsonr_ipu - threshold_kwargs: null - target_nan_mask: null - multitask_handling: mean-per-label - - name: r2_score - metric: r2_score_ipu - target_nan_mask: null - multitask_handling: mean-per-label - threshold_kwargs: null - tox21: - - name: auroc - metric: auroc_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: avpr - metric: average_precision_ipu - task: binary - multitask_handling: mean-per-label - threshold_kwargs: null - - name: f1 > 0.5 - metric: f1 - multitask_handling: flatten - target_to_int: True - num_classes: 2 - average: micro - threshold_kwargs: &threshold_05 - operator: greater - threshold: 0.5 - th_on_preds: True - th_on_target: True - - name: precision > 0.5 - metric: precision - multitask_handling: flatten - average: micro - threshold_kwargs: *threshold_05 - zinc: *qm9_metrics trainer: seed: *seed logger: - save_dir: logs/neurips2023-small/ name: *name project: *name - #early_stopping: - # monitor: *monitor - # min_delta: 0 - # patience: 10 - # mode: &mode min model_checkpoint: dirpath: models_checkpoints/neurips2023-small-gine/ filename: *name - #monitor: *monitor - #mode: *mode - save_top_k: 1 - every_n_epochs: 100 - trainer: - max_epochs: *max_epochs - min_epochs: 1 - check_val_every_n_epoch: 20 diff --git a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml index b2625d90f..e95215115 100644 --- a/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml +++ b/expts/neurips2023_configs/debug/config_large_gcn_debug.yaml @@ -330,8 +330,8 @@ predictor: # weight_decay: 1.e-7 torch_scheduler_kwargs: module_type: WarmUpLinearLR - max_num_epochs: &max_epochs 20 - warmup_epochs: 10 + 
max_num_epochs: &max_epochs 10 + warmup_epochs: 5 verbose: False scheduler_kwargs: # monitor: &monitor qm9/mae/train @@ -405,9 +405,9 @@ trainer: # monitor: *monitor # mode: *mode # save_top_k: 1 - every_n_epochs: 10 + every_n_epochs: 5 save_last: True trainer: max_epochs: *max_epochs min_epochs: 1 - check_val_every_n_epoch: 20 + check_val_every_n_epoch: 10 diff --git a/graphium/data/collate.py b/graphium/data/collate.py index 2a64763d1..933c7fe39 100644 --- a/graphium/data/collate.py +++ b/graphium/data/collate.py @@ -11,6 +11,7 @@ from graphium.features import GraphDict, to_dense_array from graphium.utils.packing import fast_packing, get_pack_sizes, node_to_pack_indices_mask from loguru import logger +from graphium.data.utils import get_keys def graphium_collate_fn( @@ -130,7 +131,7 @@ def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_size_per_pa pyg_batch = [] for pyg_graph in pyg_graphs: - for pyg_key in pyg_graph.keys: + for pyg_key in get_keys(pyg_graph): tensor = pyg_graph[pyg_key] # Convert numpy/scipy to Pytorch @@ -198,7 +199,7 @@ def collate_pyg_graph_labels(pyg_labels: List[Data]): """ pyg_batch = [] for pyg_label in pyg_labels: - for pyg_key in set(pyg_label.keys) - set(["x", "edge_index"]): + for pyg_key in set(get_keys(pyg_label)) - set(["x", "edge_index"]): tensor = pyg_label[pyg_key] # Convert numpy/scipy to Pytorch if isinstance(tensor, (ndarray, spmatrix)): @@ -255,14 +256,14 @@ def collate_labels( labels_size_dict[task] = labels_size_dict[task][1:] elif not task.startswith("graph_"): labels_size_dict[task] = [1] - - empty_task_labels = set(labels_size_dict.keys()) - set(this_label.keys) + label_keys_set = set(get_keys(this_label)) + empty_task_labels = set(labels_size_dict.keys()) - label_keys_set for task in empty_task_labels: labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) dtype = labels_dtype_dict[task] this_label[task] = torch.full([*labels_size_dict[task]], torch.nan, dtype=dtype) - for task in set(this_label.keys) - set(["x", "edge_index"]) - empty_task_labels: + for task in label_keys_set - set(["x", "edge_index"]) - empty_task_labels: labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) if not isinstance(this_label[task], (torch.Tensor)): diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py index da103208b..d598fbb40 100644 --- a/graphium/data/datamodule.py +++ b/graphium/data/datamodule.py @@ -10,6 +10,7 @@ import gc from rdkit import Chem import re +from graphium.data.utils import get_keys from loguru import logger import fsspec @@ -1666,8 +1667,7 @@ def in_dims(self): # get list of all keys corresponding to positional encoding pe_dim_dict = {} - g_keys = graph.keys - + g_keys = get_keys(graph) # ignore the normal keys for node feat and edge feat etc. 
for key in g_keys: prop = graph.get(key, None) diff --git a/graphium/data/utils.py b/graphium/data/utils.py index d3e7faeb7..6eff7c25f 100644 --- a/graphium/data/utils.py +++ b/graphium/data/utils.py @@ -98,3 +98,10 @@ def download_graphium_dataset( dataset_path_destination = dataset_path_destination.split(".")[0] return dataset_path_destination + + +def get_keys(pyg_data): + if isinstance(type(pyg_data).keys, property): + return pyg_data.keys + else: + return pyg_data.keys() diff --git a/graphium/features/featurizer.py b/graphium/features/featurizer.py index 8f3cab189..66f241663 100644 --- a/graphium/features/featurizer.py +++ b/graphium/features/featurizer.py @@ -886,6 +886,14 @@ def __init__( # default_dic.update(edata) super().__init__(default_dic) + @property + def keys(self): + return list(super().keys()) + + @property + def values(self): + return list(super().values()) + def make_pyg_graph(self, **kwargs) -> Data: """ Convert the current dictionary of parameters, containing an adjacency matrix with node/edge data diff --git a/graphium/ipu/ipu_dataloader.py b/graphium/ipu/ipu_dataloader.py index de16d3392..92a777f25 100644 --- a/graphium/ipu/ipu_dataloader.py +++ b/graphium/ipu/ipu_dataloader.py @@ -11,6 +11,7 @@ from torch_geometric.data import Data, Batch, Dataset from torch_geometric.transforms import BaseTransform +from graphium.data.utils import get_keys from graphium.ipu.ipu_utils import import_poptorch from graphium.utils.packing import ( fast_packing, @@ -155,9 +156,9 @@ def __call__( out_batch = {} # Stack tensors in the first dimension to allow IPUs to differentiate between local and global graph + all_keys = get_keys(all_batches[0]["labels"]) out_batch["labels"] = { - key: torch.stack([this_batch["labels"][key] for this_batch in all_batches], 0) - for key in all_batches[0]["labels"].keys + key: torch.stack([this_batch["labels"][key] for this_batch in all_batches], 0) for key in all_keys } out_graphs = [this_batch["features"] for this_batch in all_batches] stacked_features = deepcopy(out_graphs[0]) diff --git a/graphium/ipu/ipu_wrapper.py b/graphium/ipu/ipu_wrapper.py index 59b42f82d..de62653b4 100644 --- a/graphium/ipu/ipu_wrapper.py +++ b/graphium/ipu/ipu_wrapper.py @@ -15,6 +15,7 @@ from loguru import logger import functools import collections +from graphium.data.utils import get_keys poptorch = import_poptorch() @@ -33,7 +34,7 @@ def sortedTensorKeys(struct: BaseData) -> Iterable[str]: Find all the keys that map to a tensor value in struct. The keys are returned in sorted order.
""" - all_keys = sorted(struct.keys) + all_keys = sorted(get_keys(struct)) def isTensor(k: str) -> bool: return isinstance(struct[k], torch.Tensor) @@ -57,7 +58,7 @@ def reconstruct(self, original_structure: BaseData, tensor_iterator: Iterable[Te tensor_keys = self.sortedTensorKeys(original_structure) kwargs = {k: next(tensor_iterator) for k in tensor_keys} - for k in original_structure.keys: + for k in get_keys(original_structure): if k not in kwargs: # copy non-tensor properties to the new instance kwargs[k] = original_structure[k] diff --git a/graphium/nn/architectures/encoder_manager.py b/graphium/nn/architectures/encoder_manager.py index 6d53fa0a0..e3e48aeba 100644 --- a/graphium/nn/architectures/encoder_manager.py +++ b/graphium/nn/architectures/encoder_manager.py @@ -9,6 +9,7 @@ from torch import Tensor, nn import torch +from graphium.data.utils import get_keys from graphium.nn.encoders import ( laplace_pos_encoder, mlp_encoder, @@ -169,7 +170,7 @@ def forward(self, g: Batch) -> Batch: # If the key is already present, concatenate the pe_pooled to the pre-existing feature. for pe_key, this_pe in pe_pooled.items(): feat = this_pe - if pe_key in g.keys: + if pe_key in get_keys(g): feat = torch.cat((feat, g[pe_key]), dim=-1) g[pe_key] = feat return g diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index adae4c013..979e75534 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -12,6 +12,7 @@ from torch_geometric.data import Data # graphium imports +from graphium.data.utils import get_keys from graphium.nn.base_layers import FCLayer, get_activation, get_norm from graphium.nn.architectures.encoder_manager import EncoderManager from graphium.nn.pyg_layers import VirtualNodePyg, parse_pooling_layer_pyg @@ -1152,7 +1153,8 @@ def forward(self, g: Batch) -> Tensor: g["feat"] = g["feat"] e = None - if "edge_feat" in g.keys: + + if "edge_feat" in get_keys(g): g["edge_feat"] = g["edge_feat"] # Run the pre-processing network on node features diff --git a/graphium/nn/pyg_layers/gps_pyg.py b/graphium/nn/pyg_layers/gps_pyg.py index 9e514fcd5..99ba0032c 100644 --- a/graphium/nn/pyg_layers/gps_pyg.py +++ b/graphium/nn/pyg_layers/gps_pyg.py @@ -13,6 +13,7 @@ PNAMessagePassingPyg, MPNNPlusPyg, ) +from graphium.data.utils import get_keys from graphium.utils.decorators import classproperty from graphium.ipu.to_dense_batch import ( to_dense_batch, @@ -289,7 +290,8 @@ def _use_packing(self, batch: Batch) -> bool: """ Check if we should use packing for the batch of graphs. 
""" - return "pack_from_node_idx" in batch.keys and "pack_attn_mask" in batch.keys + batch_keys = get_keys(batch) + return "pack_from_node_idx" in batch_keys and "pack_attn_mask" in batch_keys def _to_dense_batch( self, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 67e861b56..1549efd82 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ from graphium.data import load_micro_zinc from graphium.data.dataset import SingleTaskDataset, MultitaskDataset from graphium.data.smiles_transform import smiles_to_unique_mol_ids +from graphium.data.utils import get_keys class Test_Multitask_Dataset(ut.TestCase): @@ -137,19 +138,19 @@ def test_multitask_dataset_case_2(self): for i, id in enumerate(multitask_microzinc.mol_ids): if mol_id == id: found_idx = i - + multitask_microzinc_labels = get_keys(multitask_microzinc.labels[found_idx]) if task == "SA": self.assertEqual(label_SA, multitask_microzinc.labels[found_idx]["SA"]) - self.assertFalse("score" in multitask_microzinc.labels[found_idx].keys) - self.assertFalse("logp" in multitask_microzinc.labels[found_idx].keys) + self.assertFalse("score" in multitask_microzinc_labels) + self.assertFalse("logp" in multitask_microzinc_labels) elif task == "logp": self.assertEqual(label_logp, multitask_microzinc.labels[found_idx]["logp"]) - self.assertFalse("score" in multitask_microzinc.labels[found_idx].keys) - self.assertFalse("SA" in multitask_microzinc.labels[found_idx].keys) + self.assertFalse("score" in multitask_microzinc_labels) + self.assertFalse("SA" in multitask_microzinc_labels) elif task == "score": self.assertEqual(label_score, multitask_microzinc.labels[found_idx]["score"]) - self.assertFalse("SA" in multitask_microzinc.labels[found_idx].keys) - self.assertFalse("logp" in multitask_microzinc.labels[found_idx].keys) + self.assertFalse("SA" in multitask_microzinc_labels) + self.assertFalse("logp" in multitask_microzinc_labels) def test_multitask_dataset_case_3(self): """Case: Different tasks, but with semi-intersection (some smiles unique per task, some intersect)