Torchmetrics usage improvements with classes instead of functionals #517

Merged
71 commits merged on Sep 10, 2024
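For context, this PR replaces one-shot functional metric calls (and the IPU-specific `*_ipu` variants) with stateful torchmetrics classes that accumulate per batch via `update()` and aggregate via `compute()`. The sketch below is illustrative only — it is not code from this PR — and uses the public torchmetrics API:

```python
import torch
import torchmetrics

preds = torch.tensor([0.1, 0.4, 0.9])
target = torch.tensor([0.0, 0.5, 1.0])

# Functional style (previous approach): stateless, recomputed from scratch on each call.
mae_value = torchmetrics.functional.mean_absolute_error(preds, target)

# Class style (the direction of this PR): stateful, accumulated across batches,
# with built-in synchronization across DDP workers.
mae = torchmetrics.MeanAbsoluteError()
for batch in [(preds[:2], target[:2]), (preds[2:], target[2:])]:
    mae.update(*batch)          # accumulate running state per batch
epoch_mae = mae.compute()       # aggregate once, e.g. at epoch end
mae.reset()                     # clear state for the next epoch
```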
Commits (71)
bd59098
Removed ipu metrics, since not compatible with latest torchmetrics
DomInvivo Apr 26, 2024
734ba55
Updated `MetricWrapper` to work with `update` and `compute`, compatib…
DomInvivo Apr 26, 2024
b6c578f
Changed requirements for torchmetrics
DomInvivo Apr 26, 2024
4f6e816
fixed the loss by adding `MetricToTorchMetrics`, and added a few comm…
DomInvivo Apr 26, 2024
7933ae5
Major updates to `predictor_summaries.py`
DomInvivo May 3, 2024
5849927
Improved the predictor summaries. Added GradientNormMetric
DomInvivo May 3, 2024
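The `GradientNormMetric` added here (and the `STDMetric` that appears in later commits) are not shown in this view. As a rough, hypothetical sketch — names and details assumed, not taken from the PR — a streaming standard-deviation metric in the class-based torchmetrics style could look like this:

```python
import torch
from torchmetrics import Metric

class STDMetric(Metric):
    """Streaming standard deviation over all values passed to `update`."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Scalar states; summed across batches and across DDP workers.
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_sq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, values: torch.Tensor) -> None:
        values = values.detach().flatten().float()
        self.total += values.sum()
        self.total_sq += (values**2).sum()
        self.count += values.numel()

    def compute(self) -> torch.Tensor:
        mean = self.total / self.count
        var = (self.total_sq / self.count - mean**2).clamp_min(0.0)
        return torch.sqrt(var)
```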
f355eed
Improved the task summaries and started to fix the training logging.
DomInvivo Jun 13, 2024
4372ace
Fixed test_metrics. Moved lots of `spaces.py` imports to inner functi…
DomInvivo Jul 10, 2024
7b89998
duplicated some unit-test fixes from graphium_3.0 branch
DomInvivo Jul 11, 2024
ab88952
Fixed the loading of a previous dummy model using older metrics by re…
DomInvivo Jul 11, 2024
6c58733
Minor documentation
DomInvivo Jul 11, 2024
a9a8810
Removed the loss from `predictor_summaries`
DomInvivo Jul 11, 2024
2185697
Removed epochs from task summaries
DomInvivo Jul 11, 2024
d37d818
Draft implementing the update/compute logic in the predictor.
DomInvivo Jul 11, 2024
b4524f9
Fix the std metric. Still needs testing.
DomInvivo Jul 11, 2024
5040c47
fixed all errors arising in `test_finetuning.py`
DomInvivo Jul 11, 2024
e761e08
Fixed the `test_training.py` unit test
DomInvivo Jul 12, 2024
5d60fbf
Standardized the test names
DomInvivo Jul 12, 2024
b59428a
Fixed some unit-tests that were broken by previous changes
DomInvivo Jul 12, 2024
632d4dc
Added `pytdc` to the tests
DomInvivo Jul 12, 2024
0fa2d86
Changed mamba install tdc to pip install, in the `test.yml` file
DomInvivo Jul 12, 2024
2441f43
Added '--no-deps' to TDC installation in `test.yml`
DomInvivo Jul 12, 2024
326b6e7
Woops
DomInvivo Jul 12, 2024
641fa37
Fixed issue with building docs
DomInvivo Jul 12, 2024
2b85dce
Removed old file from breaking docs building
DomInvivo Jul 12, 2024
0c93a0f
Changed to micromamba to install pytdc
DomInvivo Jul 12, 2024
ec235fc
Added tests for the `STDMetric` and `GradientNormMetric` and fixed th…
DomInvivo Jul 12, 2024
38d03e1
Implemented test of MultiTaskSummaries. Only an error left for the me…
DomInvivo Jul 12, 2024
d6f62a4
Fixed the `preds` and `targets` that were inverted in `TaskSummary`
DomInvivo Jul 13, 2024
3673884
Tried to add grad_norm to the metrics, but won't work because it's no…
DomInvivo Jul 13, 2024
29598a2
Moved the gradient metric directly to the `Predictor`
DomInvivo Jul 13, 2024
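Since the gradient norm depends on module parameters rather than on `(preds, targets)` pairs, it cannot easily live alongside the other metrics, hence the move into the `Predictor`. A hypothetical sketch of how a LightningModule can log it from the `on_before_optimizer_step` hook (not the actual Graphium implementation):

```python
import torch
from lightning.pytorch import LightningModule

class PredictorWithGradNorm(LightningModule):
    def on_before_optimizer_step(self, optimizer) -> None:
        # Runs after backward(), so .grad is populated for every trainable parameter.
        grads = [p.grad.detach().flatten() for p in self.parameters() if p.grad is not None]
        if grads:
            grad_norm = torch.linalg.vector_norm(torch.cat(grads))  # global L2 norm
            self.log("train/grad_norm", grad_norm, on_step=True)
```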
6260fa1
Removed file_opener and read_file
DomInvivo Jul 13, 2024
10a1017
Fixed predictor grad_norm
DomInvivo Jul 13, 2024
8aa0f2b
Merge branch 'graphium_3.0' into torchmetrics
DomInvivo Jul 13, 2024
90c0ca4
Fixed the progress bar logging to newest version. Fixed minor issues …
DomInvivo Jul 15, 2024
be99d94
Merge remote-tracking branch 'origin/torchmetrics' into torchmetrics
DomInvivo Jul 15, 2024
44b66b5
fixed some issue with older version of torchmetrics
DomInvivo Jul 15, 2024
5c421a6
Fixed reversed preds/targets. Fixed random sampling to take in the DF…
DomInvivo Jul 16, 2024
f15cd9a
fixed missing metrics computation on `on_train_batch_end`
DomInvivo Jul 16, 2024
2142313
Added toymix training to the unit-tests. Also useful to run in debug …
DomInvivo Jul 16, 2024
99e0cd6
Adding `_global/` to some metrics logging into wandb
DomInvivo Jul 16, 2024
045ea53
Added better handling of metrics failure with `logger.warn`
DomInvivo Jul 16, 2024
d8ba606
Fixed metric issues on gpu by casting to the right device prior to `.…
DomInvivo Jul 16, 2024
1bf2734
Added losses to the metrics, such that they are computed on val and t…
DomInvivo Jul 17, 2024
68b9361
Restricting the numpy version due to issues with wandb
DomInvivo Jul 17, 2024
911dfe9
detaching preds
DomInvivo Jul 17, 2024
d34ac60
Removed cuda version restriction
DomInvivo Jul 17, 2024
b1f2e86
Removed unnecessary detach, that broke the loss
DomInvivo Jul 17, 2024
9dbd021
Minor gitignore
DomInvivo Aug 15, 2024
5a77cbe
Fixed the error due to time metrics on CPU `No backend type associate…
DomInvivo Aug 16, 2024
7fba29d
Added val epoch time
DomInvivo Aug 16, 2024
b59dc36
Added logic to avoid crashing when resetting unused metrics
DomInvivo Aug 17, 2024
da3e3a1
Added `MetricWrapper.device`
DomInvivo Aug 19, 2024
8bf0d41
Disable caching model checkpoint through WandbLogger
Aug 19, 2024
6f35ea9
Improved the testing of the metrics reset, update, compute
DomInvivo Aug 21, 2024
d2f84f2
Reverted wrong change in `train_finetune_test.py`
DomInvivo Aug 21, 2024
e9be441
Improved __len__ in MultitaskDataModule
DomInvivo Aug 21, 2024
eaf9077
Added a new logic to allow saving all preds and targets more efficien…
DomInvivo Aug 22, 2024
5432531
Fixed the concatenation to work with and without DDP. Moved to CPU fo…
DomInvivo Aug 22, 2024
8c75d77
Fixed the issue with memory leaks and devices.
DomInvivo Aug 22, 2024
5abd769
Fixed the CPU syncing of `MetricToConcatenatedTorchMetrics` and GPU f…
DomInvivo Aug 22, 2024
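The `MetricToConcatenatedTorchMetrics` class referenced above is not part of this view. As a hypothetical sketch of the underlying idea (the real class may differ), torchmetrics list states can cache predictions and targets across batches and concatenate them across DDP workers before a functional metric is applied:

```python
import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

class ConcatenatedPredsMetric(Metric):
    """Caches preds/targets across batches, then applies a functional metric once."""

    def __init__(self, metric_fn, **kwargs):
        super().__init__(**kwargs)
        self.metric_fn = metric_fn  # e.g. a torchmetrics functional applied on the full epoch
        # List states: appended per batch, concatenated across DDP workers on sync.
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.preds.append(preds.detach())
        self.target.append(target.detach())

    def compute(self) -> torch.Tensor:
        # dim_zero_cat handles both the list state (single process) and the
        # already-concatenated tensor (after distributed sync).
        return self.metric_fn(dim_zero_cat(self.preds), dim_zero_cat(self.target))
```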
fac3052
Fixed the training metrics, and grouped all epoch-time and tput metrics
DomInvivo Aug 22, 2024
d0ed816
Fixed epoch_time tracking (because train ends after val)
DomInvivo Aug 22, 2024
9b7063f
Using the `torchmetrics.Metric.sync` instead of torch_distributed
DomInvivo Aug 23, 2024
136b8b0
Fixed issue that NaNs are always removed with `mean-per-label`
DomInvivo Aug 29, 2024
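For reference, the `target_nan_mask: ignore` and `multitask_handling: mean-per-label` options used throughout the configs below score each target column on its non-NaN rows and then average the per-column scores. A rough illustrative sketch, not the Graphium code path:

```python
import torch
from torchmetrics.functional import mean_absolute_error

def mean_per_label_mae(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MAE averaged over target columns, ignoring NaN entries per column."""
    scores = []
    for col in range(target.shape[1]):
        keep = ~torch.isnan(target[:, col])          # target_nan_mask: ignore
        if keep.any():
            scores.append(mean_absolute_error(preds[keep, col], target[keep, col]))
    return torch.stack(scores).mean()                # multitask_handling: mean-per-label
```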
2724b4c
Changed the name of logging variables
DomInvivo Aug 29, 2024
141f48b
Removed some IPU logic
DomInvivo Aug 29, 2024
62f2224
Fixed the syncing of `MetricToConcatenatedTorchMetrics`
DomInvivo Aug 29, 2024
2b58fed
Fixed classification metric calculation when multitask_handling=flatten
AnujaSomthankar Aug 29, 2024
2fb7f4b
Fixed all unit-test, except those for IPU
DomInvivo Sep 7, 2024
ce4f94d
Merge branch 'graphium_3.0' into torchmetrics
AnujaSomthankar Sep 10, 2024
Files changed
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -49,6 +49,9 @@ jobs:
- name: Install library
run: python -m pip install --no-deps -e . # `-e` required for correct `coverage` run.

- name: Install test dependencies
run: micromamba install -c conda-forge pytdc # Required to run the `test_finetuning.py`

- name: Install C++ library
run: cd graphium/graphium_cpp && git clone https://github.com/pybind/pybind11.git && export PYTHONPATH=$PYTHONPATH:./pybind11 && python -m pip install . && cd ../..

1 change: 1 addition & 0 deletions .gitignore
@@ -39,6 +39,7 @@ graphium/data/cache/
graphium/data/b3lyp/
graphium/data/PCQM4Mv2/
graphium/data/PCQM4M/
graphium/data/largemix/
graphium/data/neurips2023/small-dataset/
graphium/data/neurips2023/large-dataset/
graphium/data/neurips2023/dummy-dataset/
5 changes: 0 additions & 5 deletions docs/api/graphium.ipu.md
@@ -22,11 +22,6 @@ Code for adapting to run on IPU
::: graphium.ipu.ipu_losses


## IPU Metrics
------------
::: graphium.ipu.ipu_metrics


## IPU Simple Lightning
------------
::: graphium.ipu.ipu_simple_lightning
2 changes: 1 addition & 1 deletion env.yml
@@ -31,7 +31,7 @@ dependencies:
- cuda-version # works also with CPU-only system.
- pytorch >=1.12
- lightning >=2.0
- torchmetrics >=0.7.0,<0.11
- torchmetrics
- ogb
- pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric`
- wandb
24 changes: 16 additions & 8 deletions expts/hydra-configs/tasks/loss_metrics_datamodule/toymix.yaml
@@ -1,6 +1,8 @@
# @package _global_

predictor:
target_nan_mask: ignore
multitask_handling: flatten
metrics_on_progress_bar:
qm9: ["mae"]
tox21: ["auroc"]
@@ -13,29 +15,31 @@ predictor:
metrics:
qm9: &qm9_metrics
- name: mae
metric: mae_ipu
target_nan_mask: null
metric: mae
target_nan_mask: ignore
multitask_handling: flatten
threshold_kwargs: null
- name: pearsonr
metric: pearsonr_ipu
metric: pearsonr
threshold_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score_ipu
target_nan_mask: null
metric: r2_score
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null
tox21:
- name: auroc
metric: auroc_ipu
metric: auroc
task: binary
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null
- name: avpr
metric: average_precision_ipu
metric: averageprecision
task: binary
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null
- name: f1 > 0.5
@@ -44,6 +48,8 @@ metrics:
target_to_int: True
num_classes: 2
average: micro
task: binary
target_nan_mask: ignore
threshold_kwargs: &threshold_05
operator: greater
threshold: 0.5
@@ -53,6 +59,8 @@ metrics:
metric: precision
multitask_handling: mean-per-label
average: micro
task: binary
target_nan_mask: ignore
threshold_kwargs: *threshold_05
zinc: *qm9_metrics

2 changes: 1 addition & 1 deletion expts/hydra-configs/training/toymix.yaml
@@ -23,4 +23,4 @@ trainer:
precision: 16
max_epochs: ${constants.max_epochs}
min_epochs: 1
check_val_every_n_epoch: 20
check_val_every_n_epoch: 2
2 changes: 1 addition & 1 deletion expts/run_validation_test.py
@@ -8,7 +8,7 @@
import timeit
from loguru import logger
from datetime import datetime
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
49 changes: 2 additions & 47 deletions graphium/cli/train_finetune_test.py
@@ -52,41 +52,6 @@ def cli(cfg: DictConfig) -> None:
return run_training_finetuning_testing(cfg)


def get_replication_factor(cfg):
try:
ipu_config = cfg.get("accelerator", {}).get("ipu_config", [])
for item in ipu_config:
if "replicationFactor" in item:
# Extract the number between parentheses
start = item.find("(") + 1
end = item.find(")")
if start != 0 and end != -1:
return int(item[start:end])
except Exception as e:
print(f"An error occurred: {e}")

# Return default value if replicationFactor is not found or an error occurred
return 1


def get_gradient_accumulation_factor(cfg):
"""
WARNING: This MUST be called after accelerator overrides have been applied
(i.e. after `load_accelerator` has been called)
"""
try:
# Navigate through the nested dictionaries and get the gradient accumulation factor
grad_accumulation_factor = cfg.get("trainer", {}).get("trainer", {}).get("accumulate_grad_batches", 1)

# Ensure that the extracted value is an integer
return int(grad_accumulation_factor)
except Exception as e:
print(f"An error occurred: {e}")

# Return default value if an error occurred
return 1


def get_training_batch_size(cfg):
"""
WARNING: This MUST be called after accelerator overrides have been applied
@@ -195,14 +160,6 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
## Metrics
metrics = load_metrics(cfg)

# Note: these MUST be called after `cfg, accelerator = load_accelerator(cfg)`
replicas = get_replication_factor(cfg)
gradient_acc = get_gradient_accumulation_factor(cfg)
micro_bs = get_training_batch_size(cfg)
device_iterations = get_training_device_iterations(cfg)

global_bs = replicas * gradient_acc * micro_bs * device_iterations

## Predictor
predictor = load_predictor(
config=cfg,
@@ -213,17 +170,15 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
accelerator_type=accelerator_type,
featurization=datamodule.featurization,
task_norms=datamodule.task_norms,
replicas=replicas,
gradient_acc=gradient_acc,
global_bs=global_bs,
)

logger.info(predictor.model)
logger.info(ModelSummary(predictor, max_depth=4))
metrics_on_progress_bar = predictor.get_metrics_on_progress_bar

## Trainer
date_time_suffix = datetime.now().strftime("%d.%m.%Y_%H.%M.%S")
trainer = load_trainer(cfg, accelerator_type, date_time_suffix)
trainer = load_trainer(cfg, accelerator_type, date_time_suffix, metrics_on_progress_bar=metrics_on_progress_bar)

if not testing_only:
# Add the fine-tuning callback to trainer
23 changes: 11 additions & 12 deletions graphium/config/_loader.py
@@ -15,7 +15,7 @@
# Misc
import os
from copy import deepcopy
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union, Iterable

import joblib
import mup
@@ -40,10 +40,10 @@
from graphium.trainer.metrics import MetricWrapper
from graphium.trainer.predictor import PredictorModule
from graphium.utils.command_line_utils import get_anchors_and_aliases, update_config
from graphium.trainer.progress_bar import ProgressBarMetrics

# Graphium
from graphium.utils.mup import set_base_shapes
from graphium.utils.spaces import DATAMODULE_DICT, GRAPHIUM_PRETRAINED_MODELS_DICT
from graphium.utils import fs


Expand Down Expand Up @@ -111,6 +111,8 @@ def load_datamodule(
datamodule: The datamodule used to process and load the data
"""

from graphium.utils.spaces import DATAMODULE_DICT # Avoid circular imports with `spaces.py`

cfg_data = config["datamodule"]["args"]

# Instanciate the datamodule
@@ -298,9 +300,6 @@ def load_predictor(
accelerator_type: str,
featurization: Dict[str, str] = None,
task_norms: Optional[Dict[Callable, Any]] = None,
replicas: int = 1,
gradient_acc: int = 1,
global_bs: int = 1,
) -> PredictorModule:
"""
Defining the predictor module, which handles the training logic from `lightning.LighningModule`
@@ -326,9 +325,6 @@
task_levels=task_levels,
featurization=featurization,
task_norms=task_norms,
replicas=replicas,
gradient_acc=gradient_acc,
global_bs=global_bs,
**cfg_pred,
)

@@ -345,9 +341,6 @@
task_levels=task_levels,
featurization=featurization,
task_norms=task_norms,
replicas=replicas,
gradient_acc=gradient_acc,
global_bs=global_bs,
**cfg_pred,
)

@@ -384,6 +377,7 @@ def load_trainer(
config: Union[omegaconf.DictConfig, Dict[str, Any]],
accelerator_type: str,
date_time_suffix: str = "",
metrics_on_progress_bar: Optional[Iterable[str]] = None,
) -> Trainer:
"""
Defining the pytorch-lightning Trainer module.
@@ -449,12 +443,15 @@
name += f"_{date_time_suffix}"
trainer_kwargs["logger"] = WandbLogger(name=name, **wandb_cfg)

trainer_kwargs["callbacks"] = callbacks
progress_bar_callback = ProgressBarMetrics(metrics_on_progress_bar = metrics_on_progress_bar)
callbacks.append(progress_bar_callback)

trainer = Trainer(
detect_anomaly=True,
strategy=strategy,
accelerator=accelerator_type,
devices=devices,
callbacks=callbacks,
**cfg_trainer["trainer"],
**trainer_kwargs,
)
@@ -625,6 +622,8 @@ def get_checkpoint_path(config: Union[omegaconf.DictConfig, Dict[str, Any]]) ->
Otherwise, assume it refers to a file in the checkpointing dir.
"""

from graphium.utils.spaces import GRAPHIUM_PRETRAINED_MODELS_DICT # Avoid circular imports with `spaces.py`

cfg_trainer = config["trainer"]

path = config.get("ckpt_name_for_testing", "last.ckpt")
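The `ProgressBarMetrics` callback wired up in the diff above lives in `graphium.trainer.progress_bar`, which is not part of this diff. As a hypothetical sketch of the idea (the real implementation may differ), a Lightning progress bar can restrict what it displays by overriding `get_metrics`:

```python
from typing import Dict, Iterable, Optional

from lightning.pytorch.callbacks import TQDMProgressBar

class ProgressBarMetrics(TQDMProgressBar):
    """Progress bar that only shows a selected subset of logged metrics."""

    def __init__(self, metrics_on_progress_bar: Optional[Iterable[str]] = None, **kwargs):
        super().__init__(**kwargs)
        self.metrics_on_progress_bar = set(metrics_on_progress_bar or [])

    def get_metrics(self, trainer, pl_module) -> Dict:
        metrics = super().get_metrics(trainer, pl_module)
        if not self.metrics_on_progress_bar:
            return metrics  # no filter configured: show everything
        # Always keep Lightning's own entries (loss, v_num) alongside the selection.
        keep = self.metrics_on_progress_bar | {"loss", "v_num"}
        return {name: value for name, value in metrics.items() if name in keep}
```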
24 changes: 12 additions & 12 deletions graphium/config/dummy_finetuning_from_gnn.yaml
@@ -55,7 +55,7 @@ finetuning:

constants:
seed: 42
max_epochs: 2
max_epochs: 5

accelerator:
float32_matmul_precision: medium
@@ -64,14 +64,14 @@
predictor:
random_seed: ${constants.seed}
optim_kwargs:
lr: 4.e-5
lr: 1.e-3
scheduler_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: flatten # flatten, mean-per-label

torch_scheduler_kwargs:
module_type: WarmUpLinearLR
max_num_epochs: 2
max_num_epochs: 4
warmup_epochs: 1
verbose: False

@@ -84,30 +84,30 @@ metrics:
lipophilicity_astrazeneca:
- name: mae
metric: mae
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: flatten
threshold_kwargs: null
- name: spearman
metric: spearmanr
threshold_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
- name: pearson
metric: pearsonr
threshold_kwargs: null
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
- name: r2_score
metric: r2_score
target_nan_mask: null
target_nan_mask: ignore
multitask_handling: mean-per-label
threshold_kwargs: null

trainer:
seed: ${constants.seed}
trainer:
precision: 32
max_epochs: 2
max_epochs: 5
min_epochs: 1
check_val_every_n_epoch: 1
accumulate_grad_batches: 1
@@ -122,12 +122,12 @@ datamodule:

module_type: "ADMETBenchmarkDataModule"
args:
processed_graph_data_path: datacache/processed_graph_data/dummy_finetuning_from_gnn
# TDC specific
tdc_benchmark_names: [lipophilicity_astrazeneca]
tdc_train_val_seed: ${constants.seed}

batch_size_training: 200
batch_size_inference: 200
batch_size_training: 20
batch_size_inference: 20
num_workers: 0

persistent_workers: False