Merge pull request #323 from datamol-io/lightning_v2
Lightning v2
DomInvivo authored Jun 23, 2023
2 parents 892de3b + f2c2b51 commit 2566355
Showing 23 changed files with 96 additions and 121 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/model_training/ipu_training.ipynb
@@ -653,7 +653,7 @@
"from omegaconf import DictConfig\n",
"import timeit\n",
"from loguru import logger\n",
"from pytorch_lightning.utilities.model_summary import ModelSummary\n",
"from lightning.pytorch.utilities.model_summary import ModelSummary\n",
"\n",
"# Current project imports\n",
"import graphium\n",
2 changes: 1 addition & 1 deletion docs/tutorials/model_training/ipu_training_demo.py
@@ -6,7 +6,7 @@
from omegaconf import DictConfig
import timeit
from loguru import logger
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
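Both tutorial files make the same one-line change: the old pytorch_lightning utilities are now exposed under the lightning.pytorch namespace. A minimal sketch of the new import in use (the tiny module below is a hypothetical stand-in for graphium's predictor, not part of this commit):

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.model_summary import ModelSummary

class TinyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)

# max_depth=-1 prints every submodule, exactly as before the namespace move
print(ModelSummary(TinyModule(), max_depth=-1))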
8 changes: 7 additions & 1 deletion env.yml
@@ -16,6 +16,7 @@ dependencies:
- pandas >=1.0
- scikit-learn
- fastparquet
- sympy

# viz
- matplotlib >=3.0.1
@@ -30,8 +31,9 @@ dependencies:
# ML packages
- cudatoolkit # works also with CPU-only system.
- pytorch >=1.10.2,<2.0
- tensorboard
- lightning >=2.0
- torchvision
- pytorch-lightning >=1.9
- torchmetrics >=0.7.0,<0.11
- ogb
- pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric`
@@ -64,3 +66,7 @@ dependencies:
- mkdocs-click
- markdown-include
- mike >=1.0.0

- pip
- pip:
- lightning-graphcore # optional, for using IPUs only
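The environment now pins lightning >=2.0 (replacing pytorch-lightning >=1.9) and moves IPU support into the optional lightning-graphcore pip package. A quick runtime sanity check of that layout, as a sketch:

import importlib.util
import lightning

# Lightning 2 is required; lightning-graphcore is only needed for IPU runs.
assert lightning.__version__.startswith("2."), lightning.__version__

if importlib.util.find_spec("lightning_graphcore") is None:
    print("lightning-graphcore not installed; IPU training is unavailable")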
2 changes: 1 addition & 1 deletion expts/configs/config_mpnn_10M_pcqm4m.yaml
@@ -25,7 +25,7 @@ accelerator:
loss_scaling: 1024
trainer:
trainer:
precision: 16
precision: 16-true
accumulate_grad_batches: 4

ipu_config:
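Lightning 2 takes precision as a string, so the bare 16 from pytorch-lightning 1.x becomes "16-true" here; the predictor-side dtype mapping updated later in this commit (graphium/ipu/ipu_wrapper.py) matches on that exact string. A minimal sketch of that mapping:

import torch

def precision_to_dtype(precision: str) -> torch.dtype:
    # mirrors the updated predictor-side check further down in this commit
    return torch.half if precision == "16-true" else torch.float

assert precision_to_dtype("16-true") == torch.half
assert precision_to_dtype("32-true") == torch.float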
2 changes: 1 addition & 1 deletion expts/main_run_get_fingerprints.py
@@ -6,7 +6,7 @@
import pandas as pd
import torch
import fsspec
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
2 changes: 1 addition & 1 deletion expts/main_run_multitask.py
@@ -8,7 +8,7 @@
import timeit
from loguru import logger
from datetime import datetime
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
2 changes: 1 addition & 1 deletion expts/main_run_predict.py
@@ -7,7 +7,7 @@
import numpy as np
import pandas as pd
import torch
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
2 changes: 1 addition & 1 deletion expts/main_run_test.py
@@ -4,7 +4,7 @@
import yaml
from copy import deepcopy
from omegaconf import DictConfig
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning.pytorch.utilities.model_summary import ModelSummary

# Current project imports
import graphium
48 changes: 13 additions & 35 deletions graphium/config/_loader.py
@@ -15,9 +15,9 @@
import mup

# Lightning
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger, Logger
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger, Logger

# Graphium
from graphium.utils.mup import set_base_shapes
@@ -31,37 +31,24 @@
from graphium.data.datamodule import MultitaskFromSmilesDataModule, BaseDataModule
from graphium.utils.command_line_utils import update_config, get_anchors_and_aliases

# Weights and Biases
from pytorch_lightning import Trainer


def get_accelerator(
config_acc: Union[omegaconf.DictConfig, Dict[str, Any]],
) -> str:
"""
Get the accelerator from the config file, and ensure that they are
consistant. For example, specifying `cpu` as the accelerators, but
`gpus>0` as a Trainer option will yield an error.
consistant.
"""

# Get the accelerator type
accelerator_type = config_acc["type"]

# Get the GPU info
gpus = config_acc["config_override"].get("trainer", {}).get("trainer", {}).get("gpus", 0)
if gpus > 0:
assert (accelerator_type is None) or (accelerator_type == "gpu"), "Accelerator mismatch"
accelerator_type = "gpu"

if (accelerator_type == "gpu") and (not torch.cuda.is_available()):
logger.warning(f"GPUs selected, but will be ignored since no GPU are available on this device")
accelerator_type = "cpu"

# Get the IPU info
ipus = config_acc["config_override"].get("trainer", {}).get("trainer", {}).get("ipus", 0)
if ipus > 0:
assert (accelerator_type is None) or (accelerator_type == "ipu"), "Accelerator mismatch"
accelerator_type = "ipu"
if accelerator_type == "ipu":
poptorch = import_poptorch()
if not poptorch.ipuHardwareIsAvailable():
@@ -280,7 +267,7 @@ def load_predictor(
task_norms: Optional[Dict[Callable, Any]] = None,
) -> PredictorModule:
"""
Defining the predictor module, which handles the training logic from `pytorch_lightning.LighningModule`
Defining the predictor module, which handles the training logic from `lightning.LighningModule`
Parameters:
model_class: The torch Module containing the main forward function
accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
@@ -365,7 +352,7 @@ def load_trainer(
cfg_trainer = deepcopy(config["trainer"])

# Define the IPU plugin if required
strategy = None
strategy = "auto"
if accelerator_type == "ipu":
ipu_opts, ipu_inference_opts = _get_ipu_opts(config)

@@ -377,22 +364,14 @@
gradient_accumulation=config["trainer"]["trainer"].get("accumulate_grad_batches", None),
)

from graphium.ipu.ipu_wrapper import DictIPUStrategy
from lightning_graphcore import IPUStrategy

strategy = DictIPUStrategy(training_opts=training_opts, inference_opts=inference_opts)
strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)

# Set the number of gpus to 0 if no GPU is available
_ = cfg_trainer["trainer"].pop("accelerator", None)
gpus = cfg_trainer["trainer"].pop("gpus", None)
ipus = cfg_trainer["trainer"].pop("ipus", None)
if (accelerator_type == "gpu") and (gpus is None):
gpus = 1
if (accelerator_type == "ipu") and (ipus is None):
ipus = 1
if accelerator_type != "gpu":
gpus = 0
if accelerator_type != "ipu":
ipus = 0
# Get devices
devices = cfg_trainer["trainer"].pop("devices", 1)
if accelerator_type == "ipu":
devices = 1 # number of IPUs used is defined in the ipu options files

# Remove the gradient accumulation from IPUs, since it's handled by the device
if accelerator_type == "ipu":
@@ -422,8 +401,7 @@ def load_trainer(
detect_anomaly=True,
strategy=strategy,
accelerator=accelerator_type,
ipus=ipus,
gpus=gpus,
devices=devices,
**cfg_trainer["trainer"],
**trainer_kwargs,
)
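load_trainer now follows the Lightning 2 Trainer API: the gpus= / ipus= arguments are gone, device selection goes through accelerator plus devices, and the default strategy is the string "auto" instead of None. A condensed sketch of the same wiring (build_trainer and its defaults are illustrative, not graphium's API):

import torch
from lightning.pytorch import Trainer

def build_trainer(accelerator_type: str, devices: int = 1, **trainer_kwargs) -> Trainer:
    if accelerator_type == "gpu" and not torch.cuda.is_available():
        accelerator_type = "cpu"  # same fallback as get_accelerator above
    return Trainer(
        accelerator=accelerator_type,
        devices=devices,      # replaces the old gpus= / ipus= arguments
        strategy="auto",      # Lightning 2 default instead of None
        **trainer_kwargs,
    )

trainer = build_trainer("cpu", devices=1, max_epochs=1)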
6 changes: 3 additions & 3 deletions graphium/data/datamodule.py
@@ -24,8 +24,8 @@

from sklearn.model_selection import train_test_split

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage
import lightning
from lightning.pytorch.trainer.states import RunningStage

import torch
from torch_geometric.data import Data
@@ -81,7 +81,7 @@
)


class BaseDataModule(pl.LightningDataModule):
class BaseDataModule(lightning.LightningDataModule):
def __init__(
self,
batch_size_training: int = 16,
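BaseDataModule now subclasses lightning.LightningDataModule from the top-level namespace instead of pytorch_lightning. A self-contained sketch of the new base class in use (the toy tensors stand in for graphium's SMILES datasets):

import lightning
import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyDataModule(lightning.LightningDataModule):
    def __init__(self, batch_size_training: int = 16):
        super().__init__()
        self.batch_size_training = batch_size_training

    def setup(self, stage=None):
        self.train_ds = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size_training)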
2 changes: 1 addition & 1 deletion graphium/data/multilevel_utils.py
@@ -35,7 +35,7 @@ def merge_columns(data: pd.Series):
data = data.to_list()
data = [np.array([np.nan]) if not isinstance(d, np.ndarray) and math.isnan(d) else d for d in data]
padded_data = itertools.zip_longest(*data, fillvalue=np.nan)
data = np.stack(padded_data, 1).T
data = np.stack(list(padded_data), 1).T
return data

unpacked_df: pd.DataFrame = df[label_cols].apply(unpack_column)
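itertools.zip_longest returns an iterator, and recent NumPy releases require an explicit sequence for np.stack, hence the added list() call. A small reproduction of the padded stacking (toy arrays, not graphium data):

import itertools
import numpy as np

data = [np.array([1.0, 2.0]), np.array([3.0])]
padded = itertools.zip_longest(*data, fillvalue=np.nan)
# the NaN-padded input arrays end up as columns of the result
stacked = np.stack(list(padded), 1).T
print(stacked)
# [[ 1.  3.]
#  [ 2. nan]]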
10 changes: 5 additions & 5 deletions graphium/ipu/ipu_simple_lightning.py
@@ -1,7 +1,7 @@
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
import pytorch_lightning as pl
from pytorch_lightning.strategies import IPUStrategy
from pytorch_lightning.loggers import WandbLogger
import lightning
from lightning_graphcore import IPUStrategy
from lightning.pytorch.loggers import WandbLogger

import torch
from torch import nn
@@ -60,7 +60,7 @@ def forward(self, x):
# This class shows a minimal lightning example. This example uses our own
# SimpleTorchModel which is a basic 2 conv, 2 FC torch network. It can be
# found in simple_torch_model.py.
class SimpleLightning(pl.LightningModule):
class SimpleLightning(lightning.LightningModule):
def __init__(self, in_dim, hidden_dim, kernel_size, num_classes, on_ipu):
super().__init__()
self.model = SimpleTorchModel(
@@ -144,7 +144,7 @@ def configure_optimizers(self):
ipus = 1
strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)

trainer = pl.Trainer(
trainer = lightning.Trainer(
logger=WandbLogger(),
ipus=ipus,
max_epochs=3,
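The IPU strategy now comes from the external lightning-graphcore package instead of pytorch_lightning.strategies. A hedged sketch of the new wiring, assuming a machine with Graphcore's poptorch SDK installed; the accelerator/strategy/devices arguments follow the pattern used in graphium/config/_loader.py above:

import poptorch
from lightning.pytorch import Trainer
from lightning_graphcore import IPUStrategy  # pip install lightning-graphcore

training_opts = poptorch.Options()
inference_opts = poptorch.Options()
strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)
trainer = Trainer(accelerator="ipu", devices=1, strategy=strategy, max_epochs=3)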
39 changes: 16 additions & 23 deletions graphium/ipu/ipu_wrapper.py
@@ -2,9 +2,9 @@

from torch_geometric.data import Batch
from torch import Tensor
from pytorch_lightning.strategies import IPUStrategy
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.trainer.states import RunningStage
from lightning_graphcore import IPUStrategy
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning.pytorch.trainer.states import RunningStage

from graphium.trainer.predictor import PredictorModule
from graphium.ipu.ipu_utils import import_poptorch
@@ -19,17 +19,6 @@
poptorch = import_poptorch()


class DictIPUStrategy(IPUStrategy):
def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
args = self._prepare_input(args)
args = args[0]
poptorch_model = self.poptorch_models[stage]
self.lightning_module._running_torchscript = True
out = poptorch_model(**args)
self.lightning_module._running_torchscript = False
return out


class PyGArgsParser(poptorch.ICustomArgParser):
"""
This class is responsible for converting a PyG Batch from and to
@@ -114,7 +103,8 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
outputs["loss"] = outputs["loss"].mean()
super().on_train_batch_end(outputs, batch, batch_idx)

def training_step(self, features, labels) -> Dict[str, Any]:
def training_step(self, batch, batch_idx) -> Dict[str, Any]:
features, labels = batch["features"], batch["labels"]
features, labels = self.squeeze_input_dims(features, labels)
dict_input = {"features": features, "labels": labels}
step_dict = super().training_step(dict_input, to_cpu=False)
@@ -123,15 +113,17 @@ def training_step(self, features, labels) -> Dict[str, Any]:
step_dict["loss"] = self.poptorch.identity_loss(loss, reduction="mean")
return step_dict

def validation_step(self, features, labels) -> Dict[str, Any]:
def validation_step(self, batch, batch_idx) -> Dict[str, Any]:
features, labels = batch["features"], batch["labels"]
features, labels = self.squeeze_input_dims(features, labels)
dict_input = {"features": features, "labels": labels}
step_dict = super().validation_step(dict_input, to_cpu=False)

return step_dict

def test_step(self, features, labels) -> Dict[str, Any]:
def test_step(self, batch, batch_idx) -> Dict[str, Any]:
# Build a dictionary from the tuples
features, labels = batch["features"], batch["labels"]
features, labels = self.squeeze_input_dims(features, labels)
dict_input = {"features": features, "labels": labels}
step_dict = super().test_step(dict_input, to_cpu=False)
@@ -145,17 +137,18 @@ def predict_step(self, **inputs) -> Dict[str, Any]:

return step_dict

def validation_epoch_end(self, outputs: Dict[str, Any]):
def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
# convert data that will be tracked
outputs = self.convert_from_fp16(outputs)
super().validation_epoch_end(outputs)
super().on_validation_batch_end(outputs, batch, batch_idx)

def evaluation_epoch_end(self, outputs: Dict[str, Any]):
def evaluation_epoch_end(self, outputs: Any):
outputs = self.convert_from_fp16(outputs)
super().evaluation_epoch_end(outputs)

def test_epoch_end(self, outputs: Dict[str, Any]):
def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
outputs = self.convert_from_fp16(outputs)
super().test_epoch_end(outputs)
super().on_test_batch_end(outputs, batch, batch_idx)

def configure_optimizers(self, impl=None):
if impl is None:
@@ -211,7 +204,7 @@ def _convert_features_dtype(self, feats):
return feats

def precision_to_dtype(self, precision):
return torch.half if precision in (16, "16") else torch.float
return torch.half if precision == "16-true" else torch.float

def get_num_graphs(self, data: Batch):
"""
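The wrapper's step methods now follow the Lightning 2 (batch, batch_idx) signature and unpack features/labels from the batch dictionary, while the removed *_epoch_end hooks move to per-batch on_*_batch_end callbacks. A self-contained sketch of the new signatures (the toy module below is illustrative, not graphium's predictor):

from typing import Any, Dict
import torch
from lightning.pytorch import LightningModule

class StepSignatureDemo(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        features, labels = batch["features"], batch["labels"]
        loss = torch.nn.functional.mse_loss(self.layer(features).squeeze(-1), labels)
        return {"loss": loss}

    def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
        # per-batch post-processing that used to live in validation_epoch_end
        pass

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)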
(Diffs for the remaining 10 changed files were not loaded in this view.)