Merge pull request #444 from datamol-io/baselines
Baselines for LargeMix
DomInvivo authored Aug 26, 2023
2 parents f96c1b6 + d1e8cb9 commit 8358a0d
Showing 33 changed files with 728 additions and 293 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -14,6 +14,8 @@
*.datacache.gz
lightning_logs/
logs/
multirun/
hparam-search-results/
models_checkpoints/
outputs/
out/
@@ -39,7 +41,8 @@ graphium/data/neurips2023/dummy-dataset/
graphium/data/make_data_splits/*.csv*
graphium/data/make_data_splits/*.pt*
graphium/data/make_data_splits/*.parquet*

*.csv.gz
*.pt

# Others
expts_untracked/
2 changes: 1 addition & 1 deletion README.md
@@ -105,7 +105,7 @@ However, when working with larger datasets, it is recommended to perform data pr
The following command-line will prepare the data and cache it, then use it to train a model.
```bash
# First prepare the data and cache it in `path_to_cached_data`
graphium-prepare-data datamodule.args.processed_graph_data_path=[path_to_cached_data]
graphium data prepare ++datamodule.args.processed_graph_data_path=[path_to_cached_data]

# Then train the model on the prepared data
graphium-train [...] datamodule.args.processed_graph_data_path=[path_to_cached_data]
67 changes: 64 additions & 3 deletions docs/baseline.md
@@ -1,6 +1,6 @@
# ToyMix Baseline
# ToyMix Baseline - Test set metrics

From the paper to be released soon. Below, you can see the baselines for the `ToyMix` dataset, a multitasking dataset comprising `QM9`, `Zinc12k`, and `Tox21`. The datasets and their splits are available at [this link](https://zenodo.org/record/7998401).
From the paper to be released soon. Below, you can see the baselines for the `ToyMix` dataset, a multitasking dataset comprising `QM9`, `Zinc12k`, and `Tox21`. The datasets and their splits are available at [this link](https://zenodo.org/record/7998401). The following baselines are all for models with ~150k parameters.

One can observe that the smaller datasets (`Zinc12k` and `Tox21`) benefit from adding another, unrelated task (`QM9`), whose labels are computed from DFT simulations.

@@ -25,7 +25,68 @@ One can observe that the smaller datasets (`Zinc12k` and `Tox21`) beneficiate fr
| | GINE | 0.201 ± 0.007 | 0.783 ± 0.007 | 0.345 ± 0.02 | 0.177 ± 0.0008 | 0.836 ± 0.004 | **0.455 ± 0.008** |

# LargeMix Baseline
Coming soon!
## LargeMix test set metrics

From the paper to be released soon. Below, you can see the baselines for the `LargeMix` dataset, a multitasking dataset comprising `PCQM4M_N4`, `PCQM4M_G25`, `PCBA_1328`, `L1000_VCAP`, and `L1000_MCF7`. The datasets and their splits are available at [this link](https://zenodo.org/record/7998401). The following baselines are all for models with 4-6M parameters.

One can observe that the smaller datasets (`L1000_VCAP` and `L1000_MCF7`) benefit tremendously from multitasking. Indeed, their small number of molecular samples makes it very easy for a model to overfit.

While `PCQM4M_G25` shows no noticeable change, the node predictions of `PCQM4M_N4` and the assay predictions of `PCBA_1328` take a hit; this is most likely due to underfitting, since the training loss also increases. It seems that 4-6M parameters are far from sufficient to capture all of the tasks simultaneously, which motivates the need for a larger model.

| Dataset | Model | Single-task MAE ↓ | Single-task Pearson ↑ | Single-task R² ↑ | Multi-task MAE ↓ | Multi-task Pearson ↑ | Multi-task R² ↑ |
|-----------|-------|-------------------|-----------------------|------------------|------------------|----------------------|-----------------|
| Pcqm4m_g25 | GCN | 0.2362 ± 0.0003 | 0.8781 ± 0.0005 | 0.7803 ± 0.0006 | 0.2458 ± 0.0007 | 0.8701 ± 0.0002 | **0.8189 ± 0.0726** |
| | GIN | 0.2270 ± 0.0003 | 0.8854 ± 0.0004 | 0.7912 ± 0.0006 | 0.2352 ± 0.0006 | 0.8802 ± 0.0007 | 0.7827 ± 0.0005 |
| | GINE| **0.2223 ± 0.0007** | **0.8874 ± 0.0003** | 0.7949 ± 0.0001 | 0.2315 ± 0.0002 | 0.8823 ± 0.0002 | 0.7864 ± 0.0008 |
| Pcqm4m_n4 | GCN | 0.2080 ± 0.0003 | 0.5497 ± 0.0010 | 0.2942 ± 0.0007 | 0.2040 ± 0.0001 | 0.4796 ± 0.0006 | 0.2185 ± 0.0002 |
| | GIN | 0.1912 ± 0.0027 | **0.6138 ± 0.0088** | **0.3688 ± 0.0116** | 0.1966 ± 0.0003 | 0.5198 ± 0.0008 | 0.2602 ± 0.0012 |
| | GINE| **0.1910 ± 0.0001** | 0.6127 ± 0.0003 | 0.3666 ± 0.0008 | 0.1941 ± 0.0003 | 0.5303 ± 0.0023 | 0.2701 ± 0.0034 |


| Dataset | Model | Single-task BCE ↓ | Single-task AUROC ↑ | Single-task AP ↑ | Multi-task BCE ↓ | Multi-task AUROC ↑ | Multi-task AP ↑ |
|------------|-------|-------------------|---------------------|------------------|------------------|--------------------|-----------------|
| Pcba\_1328 | GCN | **0.0316 ± 0.0000** | **0.7960 ± 0.0020** | **0.3368 ± 0.0027** | 0.0349 ± 0.0002 | 0.7661 ± 0.0031 | 0.2527 ± 0.0041 |
| | GIN | 0.0324 ± 0.0000 | 0.7941 ± 0.0018 | 0.3328 ± 0.0019 | 0.0342 ± 0.0001 | 0.7747 ± 0.0025 | 0.2650 ± 0.0020 |
| | GINE | 0.0320 ± 0.0001 | 0.7944 ± 0.0023 | 0.3337 ± 0.0027 | 0.0341 ± 0.0001 | 0.7737 ± 0.0007 | 0.2611 ± 0.0043 |
| L1000\_vcap | GCN | 0.1900 ± 0.0002 | 0.5788 ± 0.0034 | 0.3708 ± 0.0007 | 0.1872 ± 0.0020 | 0.6362 ± 0.0012 | 0.4022 ± 0.0008 |
| | GIN | 0.1909 ± 0.0005 | 0.5734 ± 0.0029 | 0.3731 ± 0.0014 | 0.1870 ± 0.0010 | 0.6351 ± 0.0014 | 0.4062 ± 0.0001 |
| | GINE | 0.1907 ± 0.0006 | 0.5708 ± 0.0079 | 0.3705 ± 0.0015 | **0.1862 ± 0.0007** | **0.6398 ± 0.0043** | **0.4068 ± 0.0023** |
| L1000\_mcf7 | GCN | 0.1869 ± 0.0003 | 0.6123 ± 0.0051 | 0.3866 ± 0.0010 | 0.1863 ± 0.0011 | **0.6401 ± 0.0021** | 0.4194 ± 0.0004 |
| | GIN | 0.1862 ± 0.0003 | 0.6202 ± 0.0091 | 0.3876 ± 0.0017 | 0.1874 ± 0.0013 | 0.6367 ± 0.0066 | **0.4198 ± 0.0036** |
| | GINE | **0.1856 ± 0.0005** | 0.6166 ± 0.0017 | 0.3892 ± 0.0035 | 0.1873 ± 0.0009 | 0.6347 ± 0.0048 | 0.4177 ± 0.0024 |
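
For reference, the sketch below illustrates the metric types reported in the tables above (MAE, Pearson, R² for the regression tasks; BCE, AUROC, AP for the classification tasks) using `numpy` and `scikit-learn` on toy arrays. It is only meant to clarify what each column measures; it is not the evaluation code used to produce these numbers.

```python
# Illustrative only: toy arrays, not the actual model predictions.
import numpy as np
from sklearn.metrics import (
    average_precision_score,
    log_loss,
    mean_absolute_error,
    r2_score,
    roc_auc_score,
)

# Regression-style tasks (e.g. the PCQM4M quantum properties): MAE, Pearson, R²
y_true = np.array([0.1, 0.5, 0.9, 1.3])
y_pred = np.array([0.2, 0.4, 1.0, 1.1])
print("MAE     :", mean_absolute_error(y_true, y_pred))
print("Pearson :", np.corrcoef(y_true, y_pred)[0, 1])
print("R2      :", r2_score(y_true, y_pred))

# Binary-classification-style tasks (e.g. the PCBA_1328 assays): BCE, AUROC, AP
labels = np.array([0, 1, 1, 0, 1])
probs = np.array([0.2, 0.7, 0.6, 0.4, 0.9])
print("BCE     :", log_loss(labels, probs))
print("AUROC   :", roc_auc_score(labels, probs))
print("AP      :", average_precision_score(labels, probs))
```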

## LargeMix training set loss

Below is the loss on the training set. One can observe that the multi-task model always underfits relative to its single-task counterpart, except on the two `L1000` datasets.

This is not surprising, as the other datasets contain two orders of magnitude more datapoints and pose a significant challenge for the relatively small models used in this analysis. This favors the single-dataset setup (which dedicates a model of the same size to one task), and we conjecture that larger models will bridge this gap.

| Dataset | Model | CE or BCE loss in single-task $\downarrow$ | CE or BCE loss in multi-task $\downarrow$ |
|------------|-------|-----------------------------------------|-----------------------------------------|
| | | | |
| **Pcqm4m\_g25** | GCN | **0.2660 ± 0.0005** | 0.2767 ± 0.0015 |
| | GIN | **0.2439 ± 0.0004** | 0.2595 ± 0.0016 |
| | GINE | **0.2424 ± 0.0007** | 0.2568 ± 0.0012 |
| | | | |
| **Pcqm4m\_n4** | GCN | **0.2515 ± 0.0002** | 0.2613 ± 0.0008 |
| | GIN | **0.2317 ± 0.0003** | 0.2512 ± 0.0008 |
| | GINE | **0.2272 ± 0.0001** | 0.2483 ± 0.0004 |
| | | | |
| **Pcba\_1328** | GCN | **0.0284 ± 0.0010** | 0.0382 ± 0.0005 |
| | GIN | **0.0249 ± 0.0017** | 0.0359 ± 0.0011 |
| | GINE | **0.0258 ± 0.0017** | 0.0361 ± 0.0008 |
| | | | |
| **L1000\_vcap** | GCN | 0.1906 ± 0.0036 | **0.1854 ± 0.0148** |
| | GIN | 0.1854 ± 0.0030 | **0.1833 ± 0.0185** |
| | GINE | **0.1860 ± 0.0025** | 0.1887 ± 0.0200 |
| | | | |
| **L1000\_mcf7** | GCN | 0.1902 ± 0.0038 | **0.1829 ± 0.0095** |
| | GIN | 0.1873 ± 0.0033 | **0.1701 ± 0.0142** |
| | GINE | 0.1883 ± 0.0039 | **0.1771 ± 0.0010** |

# UltraLarge Baseline
Coming soon!
9 changes: 0 additions & 9 deletions docs/cli_references.md

This file was deleted.

4 changes: 2 additions & 2 deletions env.yml
@@ -5,7 +5,7 @@ channels:
dependencies:
- python >=3.8
- pip
- click
- typer
- loguru
- omegaconf >=2.0.0
- tqdm
@@ -66,10 +66,10 @@ dependencies:
- mkdocstrings
- mkdocstrings-python
- mkdocs-jupyter
- mkdocs-click
- markdown-include
- mike >=1.0.0

- pip:
- lightning-graphcore # optional, for using IPUs only
- hydra-core>=1.3.2
- hydra-optuna-sweeper
2 changes: 1 addition & 1 deletion expts/hydra-configs/architecture/toymix.yaml
@@ -78,7 +78,7 @@ datamodule:
featurization_n_jobs: 30
featurization_progress: True
featurization_backend: "loky"
processed_graph_data_path: "../datacache/neurips2023-small/"
processed_graph_data_path: ${constants.datacache_path}
dataloading_from: ram
num_workers: 30 # -1 to use all
persistent_workers: False
10 changes: 5 additions & 5 deletions expts/hydra-configs/finetuning/admet.yaml
@@ -29,16 +29,16 @@ constants:
# For now, we assume a model is always fine-tuned on a single task at a time.
# You can override this value with any of the benchmark names in the TDC benchmark suite.
# See also https://tdcommons.ai/benchmark/admet_group/overview/
task: &task lipophilicity_astrazeneca
task: lipophilicity_astrazeneca

name: finetuning_${constants.task}_gcn
wandb:
name: ${constants.name}
project: *task
project: ${constants.task}
entity: multitask-gnn
save_dir: logs/${constants.task}
seed: 42
max_epochs: 10
max_epochs: 100
data_dir: expts/data/admet/${constants.task}
raise_train_error: true

@@ -57,10 +57,10 @@ finetuning:
level: graph

# Pretrained model
pretrained_model_name: dummy-pretrained-model
pretrained_model: dummy-pretrained-model
finetuning_module: task_heads # gnn
sub_module_from_pretrained: zinc # optional
new_sub_module: lipophilicity_astrazeneca # optional
new_sub_module: ${constants.task} # optional

# keep_modules_after_finetuning_module: # optional
# graph_output_nn/graph: {}
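
A note on the `${constants.task}` changes above: replacing the YAML anchor (`&task` / `*task`) with an OmegaConf interpolation presumably lets a single command-line override of `constants.task` propagate to every field that references it, since interpolations are resolved at access time rather than at parse time. A minimal sketch of that behaviour on a toy config (not the actual file):

```python
from omegaconf import OmegaConf

# Toy config mirroring the structure above; values are illustrative.
cfg = OmegaConf.create(
    {
        "constants": {"task": "lipophilicity_astrazeneca"},
        "wandb": {"project": "${constants.task}"},
        "finetuning": {"new_sub_module": "${constants.task}"},
    }
)

# Changing the single source of truth updates every interpolated field.
cfg.constants.task = "another_admet_task"
print(cfg.wandb.project)              # another_admet_task
print(cfg.finetuning.new_sub_module)  # another_admet_task
```
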
54 changes: 54 additions & 0 deletions expts/hydra-configs/hparam_search/optuna.yaml
@@ -0,0 +1,54 @@
# @package _global_
#
# For running a hyper-parameter search, we use the Optuna plugin for hydra.
# This makes optuna available as a sweeper in hydra and integrates easily with the rest of the codebase.
# For more info, see https://hydra.cc/docs/plugins/optuna_sweeper/
#
# To run a hyper-param search,
# (1) Update this config, specifically the hyper-param search space;
# (2) Run `graphium-train +hparam_search=optuna` from the command line.


defaults:
- override /hydra/sweeper: optuna
# Optuna supports various samplers (e.g. grid search, random search, TPE)
- override /hydra/sweeper/sampler: tpe

hyper_param_search:
# For the sweeper to work, the main process needs to return
# the objective value(s) (as a float) we are trying to optimize.

# Assuming this is a metric, the `objective` key specifies which metric.
# Optuna supports multi-parameter optimization as well.
# If configured correctly, you can specify multiple keys.
objective: loss/test

# Where to save results to
# NOTE (cwognum): Ideally, we would use the `hydra.sweep.dir` key, but they don't support remote paths.
# save_destination: gs://path/to/bucket
# overwrite_destination: false

hydra:
# Run in multirun mode by default (i.e. actually use the sweeper)
mode: MULTIRUN

# Changes the working directory
sweep:
dir: hparam-search-results/${constants.name}
subdir: ${hydra.job.num}

# Sweeper config
sweeper:
sampler:
seed: ${constants.seed}
direction: minimize
study_name: ${constants.name}
storage: null
n_trials: 100
n_jobs: 1

# The hyper-parameter search space definition
# See https://hydra.cc/docs/plugins/optuna_sweeper/#search-space-configuration for the options
params:
predictor.optim_kwargs.lr: tag(log, interval(0.00001, 0.001))
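
As the comments in this config note, the Optuna sweeper drives the search through the float returned by the Hydra task function. A minimal sketch of that contract is shown below; the training call and config path are placeholders, not the actual `graphium-train` entry point.

```python
import hydra
from omegaconf import DictConfig


def run_training(cfg: DictConfig) -> dict:
    # Placeholder for the real training run; returns metric name -> value.
    return {"loss/test": 0.42}


@hydra.main(version_base=None, config_path="expts/hydra-configs", config_name="main")
def main(cfg: DictConfig) -> float:
    metrics = run_training(cfg)
    # Returning this float is what lets `hydra/sweeper: optuna` optimize it.
    return float(metrics[cfg.hyper_param_search.objective])


if __name__ == "__main__":
    main()
```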

1 change: 1 addition & 0 deletions expts/hydra-configs/training/model/toymix_gcn.yaml
@@ -6,6 +6,7 @@ constants:
max_epochs: 100
data_dir: expts/data/neurips2023/small-dataset
raise_train_error: true
datacache_path: ../datacache/neurips2023-small/

trainer:
model_checkpoint:
2 changes: 2 additions & 0 deletions expts/hydra-configs/training/model/toymix_gin.yaml
@@ -3,8 +3,10 @@
constants:
name: neurips2023_small_data_gin
seed: 42
max_epochs: 100
data_dir: expts/data/neurips2023/small-dataset
raise_train_error: true
datacache_path: ../datacache/neurips2023-small/

trainer:
model_checkpoint:
30 changes: 1 addition & 29 deletions expts/main_run_multitask.py
@@ -1,33 +1,5 @@
# General imports
import os
from os.path import dirname, abspath
from omegaconf import DictConfig, OmegaConf
import timeit
from loguru import logger
from datetime import datetime
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
from graphium.config._loader import (
load_datamodule,
load_metrics,
load_architecture,
load_predictor,
load_trainer,
save_params_to_wandb,
load_accelerator,
)
from graphium.utils.safe_run import SafeRun

import hydra

# WandB
import wandb

# Set up the working directory
MAIN_DIR = dirname(dirname(abspath(graphium.__file__)))
os.chdir(MAIN_DIR)
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="hydra-configs", config_name="main")
6 changes: 3 additions & 3 deletions graphium/cli/__init__.py
@@ -1,3 +1,3 @@
from .data import data_cli
from .finetune_utils import finetune_cli
from .main import main_cli
from .data import data_app
from .finetune_utils import finetune_app
from .main import app
4 changes: 2 additions & 2 deletions graphium/cli/__main__.py
@@ -1,4 +1,4 @@
from .main import main_cli
from .main import app

if __name__ == "__main__":
main_cli()
app()
76 changes: 41 additions & 35 deletions graphium/cli/data.py
@@ -1,41 +1,22 @@
import click
import timeit
from typing import List
from omegaconf import OmegaConf
import typer
import graphium

from loguru import logger
from hydra import initialize, compose

import graphium
from .main import app
from graphium.config._loader import load_datamodule


data_app = typer.Typer(help="Graphium datasets.")
app.add_typer(data_app, name="data")

from .main import main_cli


@main_cli.group(name="data", help="Graphium datasets.")
def data_cli():
pass


@data_cli.command(name="download", help="Download a Graphium dataset.")
@click.option(
"-n",
"--name",
type=str,
required=True,
help="Name of the graphium dataset to download.",
)
@click.option(
"-o",
"--output",
type=str,
required=True,
help="Where to download the Graphium dataset.",
)
@click.option(
"--progress",
type=bool,
is_flag=True,
default=False,
required=False,
help="Whether to extract the dataset if it's a zip file.",
)
def download(name, output, progress):

@data_app.command(name="download", help="Download a Graphium dataset.")
def download(name: str, output: str, progress: bool = True):
args = {}
args["name"] = name
args["output_path"] = output
@@ -49,7 +30,32 @@ def download(name, output, progress):
logger.info(f"Dataset available at {fpath}.")


@data_cli.command(name="list", help="List available Graphium datasets.")
@data_app.command(name="list", help="List available Graphium datasets.")
def list():
logger.info("Graphium datasets:")
logger.info(graphium.data.utils.list_graphium_datasets())


@data_app.command(name="prepare", help="Prepare a Graphium dataset.")
def prepare_data(overrides: List[str]) -> None:
with initialize(version_base=None, config_path="../../expts/hydra-configs"):
cfg = compose(
config_name="main",
overrides=overrides,
)
cfg = OmegaConf.to_container(cfg, resolve=True)
st = timeit.default_timer()

# Checking that `processed_graph_data_path` is provided
path = cfg["datamodule"]["args"].get("processed_graph_data_path", None)
if path is None:
raise ValueError(
"Please provide `datamodule.args.processed_graph_data_path` to specify the caching dir."
)
logger.info(f"The caching dir is set to '{path}'")

# Data-module
datamodule = load_datamodule(cfg, "cpu")
datamodule.prepare_data()

logger.info(f"Data preparation took {timeit.default_timer() - st:.2f} seconds.")
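
For readers unfamiliar with Typer, the sketch below shows the nested-app pattern the refactor above relies on, using illustrative names rather than the actual graphium CLI: a root `Typer` app plus a sub-app registered under `data`, so commands are invoked as `mytool data download ...`.

```python
import typer

app = typer.Typer(help="Root CLI (illustrative, not the graphium entry point).")
data_app = typer.Typer(help="Dataset commands.")
app.add_typer(data_app, name="data")


@data_app.command(name="download", help="Download a dataset.")
def download(name: str, output: str, progress: bool = True):
    # Typer turns `name` and `output` into positional arguments and
    # `progress` into a `--progress/--no-progress` flag.
    typer.echo(f"Downloading {name} to {output} (progress={progress})")


if __name__ == "__main__":
    app()
```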