
Commit

Merge pull request #517 from datamol-io/torchmetrics
Torchmetrics usage improvements with classes instead of functionals
AnujaSomthankar authored Sep 10, 2024
2 parents c23dc02 + ce4f94d commit f723632
Showing 40 changed files with 1,563 additions and 2,565 deletions.
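
The core of the change, per the PR title, is switching from torchmetrics' functional API to its stateful metric classes. As a generic illustration (plain torchmetrics, not code from this PR), the sketch below shows the difference: the class API accumulates state across `update` calls and computes once, which is what makes epoch-level aggregation straightforward.

```python
# Generic functional-vs-class torchmetrics illustration; not code from this PR.
import torch
import torchmetrics
from torchmetrics.functional import mean_absolute_error

preds = torch.tensor([0.1, 0.8, 0.4])
target = torch.tensor([0.0, 1.0, 0.5])

# Functional API: stateless, recomputed from scratch on every call.
mae_functional = mean_absolute_error(preds, target)

# Class API: keeps running state, so batches can be accumulated with `update`
# and the metric computed once at epoch end.
mae_metric = torchmetrics.MeanAbsoluteError()
mae_metric.update(preds, target)
mae_class = mae_metric.compute()

assert torch.isclose(mae_functional, mae_class)
```
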
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -49,6 +49,9 @@ jobs:
- name: Install library
run: python -m pip install --no-deps -e . # `-e` required for correct `coverage` run.

- name: Install test dependencies
run: micromamba install -c conda-forge pytdc # Required to run the `test_finetuning.py`

- name: Install C++ library
run: cd graphium/graphium_cpp && git clone https://github.com/pybind/pybind11.git && export PYTHONPATH=$PYTHONPATH:./pybind11 && python -m pip install . && cd ../..

1 change: 1 addition & 0 deletions .gitignore
@@ -39,6 +39,7 @@ graphium/data/cache/
graphium/data/b3lyp/
graphium/data/PCQM4Mv2/
graphium/data/PCQM4M/
graphium/data/largemix/
graphium/data/neurips2023/small-dataset/
graphium/data/neurips2023/large-dataset/
graphium/data/neurips2023/dummy-dataset/
5 changes: 0 additions & 5 deletions docs/api/graphium.ipu.md
@@ -22,11 +22,6 @@ Code for adapting to run on IPU
::: graphium.ipu.ipu_losses


## IPU Metrics
------------
::: graphium.ipu.ipu_metrics


## IPU Simple Lightning
------------
::: graphium.ipu.ipu_simple_lightning
2 changes: 1 addition & 1 deletion env.yml
@@ -31,7 +31,7 @@ dependencies:
- cuda-version # works also with CPU-only system.
- pytorch >=1.12
- lightning >=2.0
- torchmetrics >=0.7.0,<0.11
- torchmetrics
- ogb
- pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric`
- wandb
24 changes: 16 additions & 8 deletions expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml
@@ -1,6 +1,8 @@
# @package _global_

predictor:
target_nan_mask: ignore
multitask_handling: flatten
metrics_on_progress_bar:
qm9: ["mae"]
tox21: ["auroc"]
@@ -13,29 +15,31 @@ predictor:
metrics:
qm9: &qm9_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
metric: mae
target_nan_mask: ignore
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
metric: pearsonr
threshold_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
metric: r2_score
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null
tox21:
- name: auroc
metric: auroc_ipu
metric: auroc
task: binary
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
metric: averageprecision
task: binary
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null
- name: f1 > 0.5
@@ -44,6 +48,8 @@ metrics:
target_to_int: True
num_classes: 2
average: micro
task: binary
target_nan_mask: ignore
threshold_kwargs: &threshold_05
operator: greater
threshold: 0.5
@@ -53,6 +59,8 @@ metrics:
metric: precision
multitask_handling: mean-per-label
average: micro
task: binary
target_nan_mask: ignore
threshold_kwargs: *threshold_05
zinc: *qm9_metrics

2 changes: 1 addition & 1 deletion expts/hydra-configs/training/toymix.yaml
@@ -23,4 +23,4 @@ trainer:
precision: 16
max_epochs: ${constants.max_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
check_val_every_n_epoch: 2
2 changes: 1 addition & 1 deletion expts/run_validation_test.py
@@ -8,7 +8,7 @@
import timeit
from loguru import logger
from datetime import datetime
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
49 changes: 2 additions & 47 deletions graphium/cli/train_finetune_test.py
@@ -52,41 +52,6 @@ def cli(cfg: DictConfig) -> None:
return run_training_finetuning_testing(cfg)


def get_replication_factor(cfg):
try:
ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
for item in ipu_config:
if "replicationFactor" in item:
# Extract the number between parentheses
start = item.find("(") + 1
end = item.find(")")
if start != 0 and end != -1:
return int(item[start:end])
except Exception as e:
print(f"An error occurred: {e}")

# Return default value if replicationFactor is not found or an error occurred
return 1


def get_gradient_accumulation_factor(cfg):
"""
WARNING: This MUST be called after accelerator overrides have been applied
(i.e. after `load_accelerator` has been called)
"""
try:
# Navigate through the nested dictionaries and get the gradient accumulation factor
grad_accumulation_factor = cfg.get("trainer", {}).get("trainer", {}).get("accumulate_grad_batches", 1)

# Ensure that the extracted value is an integer
return int(grad_accumulation_factor)
except Exception as e:
print(f"An error occurred: {e}")

# Return default value if an error occurred
return 1


def get_training_batch_size(cfg):
"""
WARNING: This MUST be called after accelerator overrides have been applied
@@ -195,14 +160,6 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
## Metrics
metrics = load_metrics(cfg)

# Note: these MUST be called after `cfg, accelerator = load_accelerator(cfg)`
replicas = get_replication_factor(cfg)
gradient_acc = get_gradient_accumulation_factor(cfg)
micro_bs = get_training_batch_size(cfg)
device_iterations = get_training_device_iterations(cfg)

global_bs = replicas * gradient_acc * micro_bs * device_iterations

## Predictor
predictor = load_predictor(
config=cfg,
@@ -213,17 +170,15 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
accelerator_type=accelerator_type,
featurization=datamodule.featurization,
task_norms=datamodule.task_norms,
replicas=replicas,
gradient_acc=gradient_acc,
global_bs=global_bs,
)

logger.info(predictor.model)
logger.info(ModelSummary(predictor, max_depth=4))
metrics_on_progress_bar = predictor.get_metrics_on_progress_bar

## Trainer
date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S")
trainer = load_trainer(cfg, accelerator_type, date_time_suffix)
trainer = load_trainer(cfg, accelerator_type, date_time_suffix, metrics_on_progress_bar=metrics_on_progress_bar)

if not testing_only:
# Add the fine-tuning callback to trainer
23 changes: 11 additions & 12 deletions graphium/config/_loader.py
@@ -15,7 +15,7 @@
# Misc
import os
from copy import deepcopy
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union, Iterable

import joblib
import mup
@@ -40,10 +40,10 @@
from graphium.trainer.metrics import MetricWrapper
from graphium.trainer.predictor import PredictorModule
from graphium.utils.command_line_utils import get_anchors_and_aliases, update_config
from graphium.trainer.progress_bar import ProgressBarMetrics

# Graphium
from graphium.utils.mup import set_base_shapes
from graphium.utils.spaces import DATAMODULE_DICT, GRAPHIUM_PRETRAINED_MODELS_DICT
from graphium.utils import fs


@@ -111,6 +111,8 @@ def load_datamodule(
datamodule: The datamodule used to process and load the data
"""

from graphium.utils.spaces import DATAMODULE_DICT # Avoid circular imports with `spaces.py`

cfg_data = config["datamodule"]["args"]

# Instanciate the datamodule
@@ -298,9 +300,6 @@ def load_predictor(
accelerator_type: str,
featurization: Dict[str, str] = None,
task_norms: Optional[Dict[Callable, Any]] = None,
replicas: int = 1,
gradient_acc: int = 1,
global_bs: int = 1,
) -> PredictorModule:
"""
Defining the predictor module, which handles the training logic from `lightning.LighningModule`
@@ -326,9 +325,6 @@
task_levels=task_levels,
featurization=featurization,
task_norms=task_norms,
replicas=replicas,
gradient_acc=gradient_acc,
global_bs=global_bs,
**cfg_pred,
)

@@ -345,9 +341,6 @@
task_levels=task_levels,
featurization=featurization,
task_norms=task_norms,
replicas=replicas,
gradient_acc=gradient_acc,
global_bs=global_bs,
**cfg_pred,
)

@@ -384,6 +377,7 @@ def load_trainer(
config: Union[omegaconf.DictConfig, Dict[str, Any]],
accelerator_type: str,
date_time_suffix: str = "",
metrics_on_progress_bar: Optional[Iterable[str]] = None,
) -> Trainer:
"""
Defining the pytorch-lightning Trainer module.
@@ -449,12 +443,15 @@ def load_trainer(
name += f"_{date_time_suffix}"
trainer_kwargs["logger"] = WandbLogger(name=name, **wandb_cfg)

trainer_kwargs["callbacks"] = callbacks
progress_bar_callback = ProgressBarMetrics(metrics_on_progress_bar = metrics_on_progress_bar)
callbacks.append(progress_bar_callback)

trainer = Trainer(
detect_anomaly=True,
strategy=strategy,
accelerator=accelerator_type,
devices=devices,
callbacks=callbacks,
**cfg_trainer["trainer"],
**trainer_kwargs,
)
@@ -625,6 +622,8 @@ def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, Any]]) ->
Otherwise, assume it refers to a file in the checkpointing dir.
"""

from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT # Avoid circular imports with `spaces.py`

cfg_trainer = config["trainer"]

path = config.get("ckpt_name_for_testing", "last.ckpt")
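
`load_trainer` now builds a `ProgressBarMetrics` callback from `metrics_on_progress_bar`, but that class is not shown in this diff. As a rough, hypothetical stand-in for the mechanism it likely relies on, a Lightning progress bar can be customised by subclassing `TQDMProgressBar` and overriding `get_metrics`:

```python
# Hypothetical sketch only; the real graphium.trainer.progress_bar.ProgressBarMetrics
# may be implemented differently.
from typing import Iterable, Optional

from lightning.pytorch.callbacks import TQDMProgressBar


class FilteredProgressBar(TQDMProgressBar):  # assumed name
    def __init__(self, metrics_on_progress_bar: Optional[Iterable[str]] = None):
        super().__init__()
        self.metrics_on_progress_bar = set(metrics_on_progress_bar or [])

    def get_metrics(self, trainer, pl_module):
        metrics = super().get_metrics(trainer, pl_module)
        if not self.metrics_on_progress_bar:
            return metrics
        # Keep the loss plus any metric whose name was requested for the bar.
        return {
            key: value
            for key, value in metrics.items()
            if key == "loss" or any(name in key for name in self.metrics_on_progress_bar)
        }
```
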
24 changes: 12 additions & 12 deletions graphium/config/dummy_finetuning_from_gnn.yaml
@@ -55,7 +55,7 @@ finetuning:

constants:
seed: 42
max_epochs: 2
max_epochs: 5

accelerator:
float32_matmul_precision: medium
@@ -64,14 +64,14 @@
predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-5
lr: 1.e-3
scheduler_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: flatten # flatten, mean-per-label

torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: 2
max_num_epochs: 4
warmup_epochs: 1
verbose: False

@@ -84,30 +84,30 @@ metrics:
lipophilicity_astrazeneca:
- name: mae
metric: mae
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: flatten
threshold_kwargs: null
- name: spearman
metric: spearmanr
threshold_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
- name: pearson
metric: pearsonr
threshold_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null

trainer:
seed: ${constants.seed}
trainer:
precision: 32
max_epochs: 2
max_epochs: 5
min_epochs: 1
check_val_every_n_epoch: 1
accumulate_grad_batches: 1
@@ -122,12 +122,12 @@ datamodule:

module_type: "ADMETBenchmarkDataModule"
args:
processed_graph_data_path: datacache/processed_graph_data/dummy_finetuning_from_gnn
# TDC specific
tdc_benchmark_names: [lipophilicity_astrazeneca]
tdc_train_val_seed: ${constants.seed}

batch_size_training: 200
batch_size_inference: 200
batch_size_training: 20
batch_size_inference: 20
num_workers: 0

persistent_workers: False
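
Several metric entries in this config (and in the ToyMix one earlier) rely on `multitask_handling: flatten` versus `mean-per-label`. The sketch below, using plain torchmetrics rather than graphium internals, shows what the two reductions mean for a multi-column regression target: `flatten` pools every task into a single metric call, while `mean-per-label` computes the metric per column and averages the results.

```python
# Illustration of the two multitask reductions; not graphium code.
import torch
from torchmetrics import MeanAbsoluteError

preds = torch.tensor([[0.1, 1.2], [0.4, 0.9], [0.8, 1.1]])
target = torch.tensor([[0.0, 1.0], [0.5, 1.0], [1.0, 1.0]])

# multitask_handling: flatten -> one metric over all task columns at once.
mae_flat = MeanAbsoluteError()
mae_flat.update(preds.flatten(), target.flatten())

# multitask_handling: mean-per-label -> one metric per column, then average.
per_label = []
for col in range(target.shape[1]):
    mae_col = MeanAbsoluteError()
    mae_col.update(preds[:, col], target[:, col])
    per_label.append(mae_col.compute())
mae_per_label = torch.stack(per_label).mean()

print(mae_flat.compute(), mae_per_label)
```
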