diff --git a/docs/tutorials/model_training/ipu_training.ipynb b/docs/tutorials/model_training/ipu_training.ipynb
index dd0090e1c..31687075a 100644
--- a/docs/tutorials/model_training/ipu_training.ipynb
+++ b/docs/tutorials/model_training/ipu_training.ipynb
@@ -653,7 +653,7 @@
     "from omegaconf import DictConfig\n",
     "import timeit\n",
     "from loguru import logger\n",
-    "from pytorch_lightning.utilities.model_summary import ModelSummary\n",
+    "from lightning.pytorch.utilities.model_summary import ModelSummary\n",
     "\n",
     "# Current project imports\n",
     "import graphium\n",
diff --git a/docs/tutorials/model_training/ipu_training_demo.py b/docs/tutorials/model_training/ipu_training_demo.py
index 250e556ae..dcdec5ea1 100644
--- a/docs/tutorials/model_training/ipu_training_demo.py
+++ b/docs/tutorials/model_training/ipu_training_demo.py
@@ -6,7 +6,7 @@
 from omegaconf import DictConfig
 import timeit
 from loguru import logger
-from pytorch_lightning.utilities.model_summary import ModelSummary
+from lightning.pytorch.utilities.model_summary import ModelSummary

 # Current project imports
 import graphium
diff --git a/env.yml b/env.yml
index 56a690852..69f8d6850 100644
--- a/env.yml
+++ b/env.yml
@@ -16,6 +16,7 @@ dependencies:
   - pandas >=1.0
   - scikit-learn
   - fastparquet
+  - sympy

   # viz
   - matplotlib >=3.0.1
@@ -30,8 +31,9 @@ dependencies:
   # ML packages
   - cudatoolkit # works also with CPU-only system.
   - pytorch >=1.10.2,<2.0
+  - tensorboard
+  - lightning >=2.0
   - torchvision
-  - pytorch-lightning >=1.9
   - torchmetrics >=0.7.0,<0.11
   - ogb
   - pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric`
@@ -64,3 +66,7 @@ dependencies:
   - mkdocs-click
   - markdown-include
   - mike >=1.0.0
+
+  - pip
+  - pip:
+    - lightning-graphcore # optional, for using IPUs only
\ No newline at end of file
diff --git a/expts/configs/config_mpnn_10M_pcqm4m.yaml b/expts/configs/config_mpnn_10M_pcqm4m.yaml
index de619cf2e..63fab1970 100644
--- a/expts/configs/config_mpnn_10M_pcqm4m.yaml
+++ b/expts/configs/config_mpnn_10M_pcqm4m.yaml
@@ -25,7 +25,7 @@ accelerator:
       loss_scaling: 1024
     trainer:
       trainer:
-        precision: 16
+        precision: 16-true
         accumulate_grad_batches: 4

   ipu_config:
diff --git a/expts/main_run_get_fingerprints.py b/expts/main_run_get_fingerprints.py
index 67f3cc015..94d7f066e 100644
--- a/expts/main_run_get_fingerprints.py
+++ b/expts/main_run_get_fingerprints.py
@@ -6,7 +6,7 @@
 import pandas as pd
 import torch
 import fsspec
-from pytorch_lightning.utilities.model_summary import ModelSummary
+from lightning.pytorch.utilities.model_summary import ModelSummary

 # Current project imports
 import graphium
diff --git a/expts/main_run_multitask.py b/expts/main_run_multitask.py
index 04b11aa39..569c9b4be 100644
--- a/expts/main_run_multitask.py
+++ b/expts/main_run_multitask.py
@@ -8,7 +8,7 @@
 import timeit
 from loguru import logger
 from datetime import datetime
-from pytorch_lightning.utilities.model_summary import ModelSummary
+from lightning.pytorch.utilities.model_summary import ModelSummary

 # Current project imports
 import graphium
diff --git a/expts/main_run_predict.py b/expts/main_run_predict.py
index 0c23e39c9..60d0b6513 100644
--- a/expts/main_run_predict.py
+++ b/expts/main_run_predict.py
@@ -7,7 +7,7 @@
 import numpy as np
 import pandas as pd
 import torch
-from pytorch_lightning.utilities.model_summary import ModelSummary
+from lightning.pytorch.utilities.model_summary import ModelSummary

 # Current project imports
 import graphium
diff --git a/expts/main_run_test.py b/expts/main_run_test.py
index c3bc5a62d..fd247044b 100644
--- a/expts/main_run_test.py
+++ b/expts/main_run_test.py
@@ -4,7 +4,7 @@
 import yaml
 from copy import deepcopy
 from omegaconf import DictConfig
-from pytorch_lightning.utilities.model_summary import ModelSummary
+from lightning.pytorch.utilities.model_summary import ModelSummary

 # Current project imports
 import graphium
diff --git a/graphium/config/_loader.py b/graphium/config/_loader.py
index 76441598a..be85a18fd 100644
--- a/graphium/config/_loader.py
+++ b/graphium/config/_loader.py
@@ -15,9 +15,9 @@
 import mup

 # Lightning
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning import Trainer
-from pytorch_lightning.loggers import WandbLogger, Logger
+from lightning import Trainer
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.loggers import WandbLogger, Logger

 # Graphium
 from graphium.utils.mup import set_base_shapes
@@ -31,37 +31,24 @@
 from graphium.data.datamodule import MultitaskFromSmilesDataModule, BaseDataModule
 from graphium.utils.command_line_utils import update_config, get_anchors_and_aliases

-# Weights and Biases
-from pytorch_lightning import Trainer
-

 def get_accelerator(
     config_acc: Union[omegaconf.DictConfig, Dict[str, Any]],
 ) -> str:
     """
     Get the accelerator from the config file, and ensure that they are
-    consistant. For example, specifying `cpu` as the accelerators, but
-    `gpus>0` as a Trainer option will yield an error.
+    consistant.
     """

     # Get the accelerator type
     accelerator_type = config_acc["type"]

     # Get the GPU info
-    gpus = config_acc["config_override"].get("trainer", {}).get("trainer", {}).get("gpus", 0)
-    if gpus > 0:
-        assert (accelerator_type is None) or (accelerator_type == "gpu"), "Accelerator mismatch"
-        accelerator_type = "gpu"
-
     if (accelerator_type == "gpu") and (not torch.cuda.is_available()):
         logger.warning(f"GPUs selected, but will be ignored since no GPU are available on this device")
         accelerator_type = "cpu"

     # Get the IPU info
-    ipus = config_acc["config_override"].get("trainer", {}).get("trainer", {}).get("ipus", 0)
-    if ipus > 0:
-        assert (accelerator_type is None) or (accelerator_type == "ipu"), "Accelerator mismatch"
-        accelerator_type = "ipu"
     if accelerator_type == "ipu":
         poptorch = import_poptorch()
         if not poptorch.ipuHardwareIsAvailable():
@@ -280,7 +267,7 @@ def load_predictor(
     task_norms: Optional[Dict[Callable, Any]] = None,
 ) -> PredictorModule:
     """
-    Defining the predictor module, which handles the training logic from `pytorch_lightning.LighningModule`
+    Defining the predictor module, which handles the training logic from `lightning.LighningModule`
     Parameters:
         model_class: The torch Module containing the main forward function
         accelerator_type: The accelerator type, e.g. "cpu", "gpu", "ipu"
@@ -365,7 +352,7 @@ def load_trainer(
     cfg_trainer = deepcopy(config["trainer"])

     # Define the IPU plugin if required
-    strategy = None
+    strategy = "auto"
     if accelerator_type == "ipu":
         ipu_opts, ipu_inference_opts = _get_ipu_opts(config)

@@ -377,22 +364,14 @@ def load_trainer(
             gradient_accumulation=config["trainer"]["trainer"].get("accumulate_grad_batches", None),
         )

-        from graphium.ipu.ipu_wrapper import DictIPUStrategy
+        from lightning_graphcore import IPUStrategy

-        strategy = DictIPUStrategy(training_opts=training_opts, inference_opts=inference_opts)
+        strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)

-    # Set the number of gpus to 0 if no GPU is available
-    _ = cfg_trainer["trainer"].pop("accelerator", None)
-    gpus = cfg_trainer["trainer"].pop("gpus", None)
-    ipus = cfg_trainer["trainer"].pop("ipus", None)
-    if (accelerator_type == "gpu") and (gpus is None):
-        gpus = 1
-    if (accelerator_type == "ipu") and (ipus is None):
-        ipus = 1
-    if accelerator_type != "gpu":
-        gpus = 0
-    if accelerator_type != "ipu":
-        ipus = 0
+    # Get devices
+    devices = cfg_trainer["trainer"].pop("devices", 1)
+    if accelerator_type == "ipu":
+        devices = 1  # number of IPUs used is defined in the ipu options files

     # Remove the gradient accumulation from IPUs, since it's handled by the device
     if accelerator_type == "ipu":
@@ -422,8 +401,7 @@ def load_trainer(
         detect_anomaly=True,
         strategy=strategy,
         accelerator=accelerator_type,
-        ipus=ipus,
-        gpus=gpus,
+        devices=devices,
         **cfg_trainer["trainer"],
         **trainer_kwargs,
     )
diff --git a/graphium/data/datamodule.py b/graphium/data/datamodule.py
index 2d17e09a3..ed2fdcc2a 100644
--- a/graphium/data/datamodule.py
+++ b/graphium/data/datamodule.py
@@ -24,8 +24,8 @@
 from sklearn.model_selection import train_test_split

-import pytorch_lightning as pl
-from pytorch_lightning.trainer.states import RunningStage
+import lightning
+from lightning.pytorch.trainer.states import RunningStage

 import torch
 from torch_geometric.data import Data

@@ -81,7 +81,7 @@
 )


-class BaseDataModule(pl.LightningDataModule):
+class BaseDataModule(lightning.LightningDataModule):
     def __init__(
         self,
         batch_size_training: int = 16,
diff --git a/graphium/data/multilevel_utils.py b/graphium/data/multilevel_utils.py
index de8ef96c1..388bcbeb9 100644
--- a/graphium/data/multilevel_utils.py
+++ b/graphium/data/multilevel_utils.py
@@ -35,7 +35,7 @@ def merge_columns(data: pd.Series):
         data = data.to_list()
         data = [np.array([np.nan]) if not isinstance(d, np.ndarray) and math.isnan(d) else d for d in data]
         padded_data = itertools.zip_longest(*data, fillvalue=np.nan)
-        data = np.stack(padded_data, 1).T
+        data = np.stack(list(padded_data), 1).T
         return data

     unpacked_df: pd.DataFrame = df[label_cols].apply(unpack_column)
diff --git a/graphium/ipu/ipu_simple_lightning.py b/graphium/ipu/ipu_simple_lightning.py
index b6414e205..8f85e3444 100644
--- a/graphium/ipu/ipu_simple_lightning.py
+++ b/graphium/ipu/ipu_simple_lightning.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2021 Graphcore Ltd. All rights reserved.
-import pytorch_lightning as pl
-from pytorch_lightning.strategies import IPUStrategy
-from pytorch_lightning.loggers import WandbLogger
+import lightning
+from lightning_graphcore import IPUStrategy
+from lightning.pytorch.loggers import WandbLogger

 import torch
 from torch import nn
@@ -60,7 +60,7 @@ def forward(self, x):
 # This class shows a minimal lightning example. This example uses our own
 # SimpleTorchModel which is a basic 2 conv, 2 FC torch network. It can be
 # found in simple_torch_model.py.
-class SimpleLightning(pl.LightningModule):
+class SimpleLightning(lightning.LightningModule):
     def __init__(self, in_dim, hidden_dim, kernel_size, num_classes, on_ipu):
         super().__init__()
         self.model = SimpleTorchModel(
@@ -144,7 +144,7 @@ def configure_optimizers(self):
         ipus = 1

     strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)
-    trainer = pl.Trainer(
+    trainer = lightning.Trainer(
         logger=WandbLogger(),
         ipus=ipus,
         max_epochs=3,
diff --git a/graphium/ipu/ipu_wrapper.py b/graphium/ipu/ipu_wrapper.py
index 262d5b338..59b42f82d 100644
--- a/graphium/ipu/ipu_wrapper.py
+++ b/graphium/ipu/ipu_wrapper.py
@@ -2,9 +2,9 @@
 from torch_geometric.data import Batch
 from torch import Tensor

-from pytorch_lightning.strategies import IPUStrategy
-from pytorch_lightning.utilities.types import STEP_OUTPUT
-from pytorch_lightning.trainer.states import RunningStage
+from lightning_graphcore import IPUStrategy
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from lightning.pytorch.trainer.states import RunningStage

 from graphium.trainer.predictor import PredictorModule
 from graphium.ipu.ipu_utils import import_poptorch
@@ -19,17 +19,6 @@
 poptorch = import_poptorch()


-class DictIPUStrategy(IPUStrategy):
-    def _step(self, stage: RunningStage, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
-        args = self._prepare_input(args)
-        args = args[0]
-        poptorch_model = self.poptorch_models[stage]
-        self.lightning_module._running_torchscript = True
-        out = poptorch_model(**args)
-        self.lightning_module._running_torchscript = False
-        return out
-
-
 class PyGArgsParser(poptorch.ICustomArgParser):
     """
     This class is responsible for converting a PyG Batch from and to
@@ -114,7 +103,8 @@ def on_train_batch_end(self, outputs, batch, batch_idx):
         outputs["loss"] = outputs["loss"].mean()
         super().on_train_batch_end(outputs, batch, batch_idx)

-    def training_step(self, features, labels) -> Dict[str, Any]:
+    def training_step(self, batch, batch_idx) -> Dict[str, Any]:
+        features, labels = batch["features"], batch["labels"]
         features, labels = self.squeeze_input_dims(features, labels)
         dict_input = {"features": features, "labels": labels}
         step_dict = super().training_step(dict_input, to_cpu=False)
@@ -123,15 +113,17 @@
         step_dict["loss"] = self.poptorch.identity_loss(loss, reduction="mean")
         return step_dict

-    def validation_step(self, features, labels) -> Dict[str, Any]:
+    def validation_step(self, batch, batch_idx) -> Dict[str, Any]:
+        features, labels = batch["features"], batch["labels"]
         features, labels = self.squeeze_input_dims(features, labels)
         dict_input = {"features": features, "labels": labels}
         step_dict = super().validation_step(dict_input, to_cpu=False)

         return step_dict

-    def test_step(self, features, labels) -> Dict[str, Any]:
+    def test_step(self, batch, batch_idx) -> Dict[str, Any]:
         # Build a dictionary from the tuples
+        features, labels = batch["features"], batch["labels"]
         features, labels = self.squeeze_input_dims(features, labels)
         dict_input = {"features": features, "labels": labels}
         step_dict = super().test_step(dict_input, to_cpu=False)
@@ -145,17 +137,18 @@ def predict_step(self, **inputs) -> Dict[str, Any]:

         return step_dict

-    def validation_epoch_end(self, outputs: Dict[str, Any]):
+    def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
+        # convert data that will be tracked
         outputs = self.convert_from_fp16(outputs)
-        super().validation_epoch_end(outputs)
+        super().on_validation_batch_end(outputs, batch, batch_idx)

-    def evaluation_epoch_end(self, outputs: Dict[str, Any]):
+    def evaluation_epoch_end(self, outputs: Any):
         outputs = self.convert_from_fp16(outputs)
         super().evaluation_epoch_end(outputs)

-    def test_epoch_end(self, outputs: Dict[str, Any]):
+    def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
         outputs = self.convert_from_fp16(outputs)
-        super().test_epoch_end(outputs)
+        super().on_test_batch_end(outputs, batch, batch_idx)

     def configure_optimizers(self, impl=None):
         if impl is None:
@@ -211,7 +204,7 @@ def _convert_features_dtype(self, feats):
         return feats

     def precision_to_dtype(self, precision):
-        return torch.half if precision in (16, "16") else torch.float
+        return torch.half if precision == "16-true" else torch.float

     def get_num_graphs(self, data: Batch):
         """
diff --git a/graphium/trainer/predictor.py b/graphium/trainer/predictor.py
index bfe122520..aff02144f 100644
--- a/graphium/trainer/predictor.py
+++ b/graphium/trainer/predictor.py
@@ -7,7 +7,7 @@
 import torch
 from torch import nn, Tensor

-import pytorch_lightning as pl
+import lightning
 from torch_geometric.data import Data, Batch

 from mup.optim import MuAdam
@@ -22,7 +22,7 @@
 }


-class PredictorModule(pl.LightningModule):
+class PredictorModule(lightning.LightningModule):
     def __init__(
         self,
         model_class: Type[nn.Module],
@@ -158,6 +158,8 @@ def __init__(
         # throughput estimation
         self.mean_val_time_tracker = MovingAverageTracker()
         self.mean_val_tput_tracker = MovingAverageTracker()
+        self.validation_step_outputs = []
+        self.test_step_outputs = []
         self.epoch_start_time = None

     def forward(
@@ -550,47 +552,46 @@ def on_train_epoch_end(self) -> None:
             self.epoch_start_time = None
             self.log("epoch_time", torch.tensor(epoch_time))

-    def training_epoch_end(self, outputs: Dict):
-        """
-        Nothing happens at the end of the training epoch.
-        It serves no purpose to do a general step for the training,
-        but it can explode the RAM when using a large dataset.
-        """
-        pass
-
     def on_validation_epoch_start(self) -> None:
         self.mean_val_time_tracker.reset()
         self.mean_val_tput_tracker.reset()
         return super().on_validation_epoch_start()

-    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_start(self, batch: Any, batch_idx: int) -> None:
         self.validation_batch_start_time = time.time()
-        return super().on_validation_batch_start(batch, batch_idx, dataloader_idx)
+        return super().on_validation_batch_start(batch, batch_idx)

-    def on_validation_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
         val_batch_time = time.time() - self.validation_batch_start_time
+        self.validation_step_outputs.append(outputs)
         self.mean_val_time_tracker.update(val_batch_time)
         num_graphs = self.get_num_graphs(batch["features"])
         self.mean_val_tput_tracker.update(num_graphs / val_batch_time)
-        return super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
+        return super().on_validation_batch_end(outputs, batch, batch_idx)

-    def validation_epoch_end(self, outputs: Dict[str, Any]):
-        metrics_logs = self._general_epoch_end(outputs=outputs, step_name="val")
+    def on_validation_epoch_end(self) -> None:
+        metrics_logs = self._general_epoch_end(outputs=self.validation_step_outputs, step_name="val")
+        self.validation_step_outputs.clear()
         concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs)
-        concatenated_metrics_logs["val/mean_time"] = self.mean_val_time_tracker.mean_value
+        concatenated_metrics_logs["val/mean_time"] = torch.tensor(self.mean_val_time_tracker.mean_value)
         concatenated_metrics_logs["val/mean_tput"] = self.mean_val_tput_tracker.mean_value
+
         if hasattr(self.optimizers(), "param_groups"):
             lr = self.optimizers().param_groups[0]["lr"]
-            concatenated_metrics_logs["lr"] = lr
-            concatenated_metrics_logs["n_epochs"] = self.current_epoch
+            concatenated_metrics_logs["lr"] = torch.tensor(lr)
+            concatenated_metrics_logs["n_epochs"] = torch.tensor(self.current_epoch, dtype=torch.float32)
         self.log_dict(concatenated_metrics_logs)

         # Save yaml file with the per-task metrics summaries
         full_dict = {}
         full_dict.update(self.task_epoch_summary.get_dict_summary())

-    def test_epoch_end(self, outputs: Dict[str, Any]):
-        metrics_logs = self._general_epoch_end(outputs=outputs, step_name="test")
+    def on_test_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
+        self.test_step_outputs.append(outputs)
+
+    def on_test_epoch_end(self) -> None:
+        metrics_logs = self._general_epoch_end(outputs=self.test_step_outputs, step_name="test")
+        self.test_step_outputs.clear()
         concatenated_metrics_logs = self.task_epoch_summary.concatenate_metrics_logs(metrics_logs)
         self.log_dict(concatenated_metrics_logs)

diff --git a/notebooks/dev-datamodule-invalidate-cache.ipynb b/notebooks/dev-datamodule-invalidate-cache.ipynb
index a9fca0c8c..515ad943b 100644
--- a/notebooks/dev-datamodule-invalidate-cache.ipynb
+++ b/notebooks/dev-datamodule-invalidate-cache.ipynb
@@ -22,7 +22,7 @@
     "import tempfile\n",
     "\n",
     "import numpy as np\n",
-    "import pytorch_lightning as pl\n",
+    "import lightning\n",
     "import torch\n",
     "import datamol as dm\n",
     "\n",
diff --git a/notebooks/dev-datamodule-ogb.ipynb b/notebooks/dev-datamodule-ogb.ipynb
index 0822fbc70..ea3772c1d 100644
--- a/notebooks/dev-datamodule-ogb.ipynb
+++ b/notebooks/dev-datamodule-ogb.ipynb
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -49,7 +50,7 @@
     "\n",
     "import numpy as np\n",
     "import pandas as pd\n",
-    "import pytorch_lightning as pl\n",
+    "import lightning\n",
     "import torch\n",
     "import datamol as dm\n",
     "\n",
diff --git a/notebooks/dev-datamodule.ipynb b/notebooks/dev-datamodule.ipynb
index 69491d88f..91c21cdfd 100644
--- a/notebooks/dev-datamodule.ipynb
+++ b/notebooks/dev-datamodule.ipynb
@@ -24,7 +24,7 @@
     "import tempfile\n",
     "\n",
     "import numpy as np\n",
-    "import pytorch_lightning as pl\n",
+    "import lightning\n",
     "import torch\n",
     "import datamol as dm\n",
     "\n",
@@ -177,6 +177,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
diff --git a/notebooks/dev-pretrained.ipynb b/notebooks/dev-pretrained.ipynb
index 1cb2b5677..bdaaba6fa 100644
--- a/notebooks/dev-pretrained.ipynb
+++ b/notebooks/dev-pretrained.ipynb
@@ -25,7 +25,7 @@
     "from loguru import logger\n",
     "\n",
     "import numpy as np\n",
-    "import pytorch_lightning as pl\n",
+    "import lightning\n",
     "import torch\n",
     "import datamol as dm\n",
     "import pandas as pd\n",
@@ -34,6 +34,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -212,6 +213,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -319,6 +321,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
diff --git a/notebooks/dev.ipynb b/notebooks/dev.ipynb
index e84324e2a..3d466281a 100644
--- a/notebooks/dev.ipynb
+++ b/notebooks/dev.ipynb
@@ -25,7 +25,7 @@
     "from loguru import logger\n",
     "\n",
     "import numpy as np\n",
-    "import pytorch_lightning as pl\n",
+    "import lightning\n",
     "import torch\n",
     "import datamol as dm\n",
     "import pandas as pd\n",
diff --git a/profiling/profile_predictor.py b/profiling/profile_predictor.py
index 0e3828d05..be9810d00 100644
--- a/profiling/profile_predictor.py
+++ b/profiling/profile_predictor.py
@@ -11,7 +11,7 @@
     load_predictor,
     load_architecture,
 )
-from pytorch_lightning import Trainer
+from lightning import Trainer


 def main():
diff --git a/requirements_ipu.txt b/requirements_ipu.txt
index b884153ed..cdda07668 100644
--- a/requirements_ipu.txt
+++ b/requirements_ipu.txt
@@ -7,6 +7,7 @@ loguru
 tqdm
 numpy
 scipy >==1.4
+sympy
 pandas >==1.0
 scikit-learn
 seaborn
@@ -24,6 +25,7 @@ mordred
 umap-learn
 pytest >==6.0
 pytest-cov
+pytest-xdist
 black >=23
 jupyterlab
 ipywidgets
@@ -53,4 +55,5 @@ fastparquet
 torch-scatter==2.1.0
 torch-sparse==0.6.15
 torchvision==0.14.1+cpu
-git+https://github.com/joao-alex-cunha/lightning
\ No newline at end of file
+lightning @ git+https://github.com/Lightning-AI/lightning@ca30fd7752582201a3966806c92e3acbbaf2a045
+lightning-graphcore @ git+https://github.com/Lightning-AI/lightning-Graphcore
diff --git a/tests/config_test_ipu_dataloader.yaml b/tests/config_test_ipu_dataloader.yaml
index ea90ee2d0..a3606c3d4 100644
--- a/tests/config_test_ipu_dataloader.yaml
+++ b/tests/config_test_ipu_dataloader.yaml
@@ -45,6 +45,7 @@ datamodule:
       idx_col: null # This may not always be provided
       weights_col: null # This may not always be provided
       weights_type: null # This may not always be provided
+      task_level: graph
     alpha:
       df: null
       df_path: *df_path
@@ -58,6 +59,7 @@ datamodule:
       idx_col: null # This may not always be provided
       weights_col: null # This may not always be provided
       weights_type: null # This may not always be provided
+      task_level: graph
   # Featurization
   prepare_dict_or_graph: pyg:graph
   featurization_n_jobs: 0
diff --git a/tests/test_ipu_dataloader.py b/tests/test_ipu_dataloader.py
index 6b94a67e2..08550a714 100644
--- a/tests/test_ipu_dataloader.py
+++ b/tests/test_ipu_dataloader.py
@@ -4,8 +4,8 @@
 import numpy as np
 from copy import deepcopy
 from warnings import warn
-from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.strategies import IPUStrategy
+from lightning import Trainer, LightningModule
+from lightning_graphcore import IPUStrategy
 from functools import partial

 import torch
@@ -159,7 +159,8 @@ def test_poptorch_simple_deviceiterations_gradient_accumulation(self):
             max_epochs=2,
             strategy=strategy,
             num_sanity_val_steps=0,
-            ipus=1,
+            accelerator="ipu",
+            devices=1,
         )

         trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
@@ -239,18 +240,6 @@ def assert_shapes(self, dict_input, step):

             def get_progress_bar_dict(self):
                 return {}

-            def on_train_batch_end(self, *args, **kwargs):
-                return
-
-            def on_validation_batch_end(self, *args, **kwargs):
-                return
-
-            def validation_epoch_end(self, *args, **kwargs):
-                return
-
-            def on_train_epoch_end(self) -> None:
-                return
-
             def configure_optimizers(self):
                 return torch.optim.Adam(self.parameters(), lr=1e-3)
@@ -264,8 +253,6 @@ def squeeze_input_dims(self, features, labels):

                 return features, labels

-        from graphium.ipu.ipu_wrapper import DictIPUStrategy
-
         gradient_accumulation = 3
         device_iterations = 5
         batch_size = 7
@@ -306,7 +293,7 @@ def squeeze_input_dims(self, features, labels):
             metrics=metrics,
             **cfg["predictor"],
         )
-        strategy = DictIPUStrategy(training_opts=training_opts, inference_opts=inference_opts)
+        strategy = IPUStrategy(training_opts=training_opts, inference_opts=inference_opts)
         trainer = Trainer(
             logger=False, enable_checkpointing=False, max_epochs=2, strategy=strategy, num_sanity_val_steps=0
         )