diff --git a/configs/model/mnist.yaml b/configs/model/mnist.yaml
index 6f9c2fa..68c980e 100644
--- a/configs/model/mnist.yaml
+++ b/configs/model/mnist.yaml
@@ -22,4 +22,4 @@ net:
   output_size: 10
 
 # compile model for faster training with pytorch 2.0
-compile: false
+compile_model: false
diff --git a/src/eval.py b/src/eval.py
index fa35c09..9ec50bf 100644
--- a/src/eval.py
+++ b/src/eval.py
@@ -1,11 +1,13 @@
 """Main evaluation script."""
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import hydra
 import torch
-from lightning import LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.loggers import Logger
+
+if TYPE_CHECKING:
+    from lightning import LightningDataModule, LightningModule, Trainer
+    from lightning.pytorch.loggers import Logger
 from omegaconf import DictConfig
 
 from src.utils import (
diff --git a/src/models/components/simple_dense_net.py b/src/models/components/simple_dense_net.py
index f33f6a0..806b389 100644
--- a/src/models/components/simple_dense_net.py
+++ b/src/models/components/simple_dense_net.py
@@ -1,3 +1,5 @@
+"""Simple dense neural network."""
+
 import torch
 from torch import nn
 
diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py
index fad8d8f..908f70d 100644
--- a/src/models/mnist_module.py
+++ b/src/models/mnist_module.py
@@ -1,3 +1,5 @@
+"""MNIST simple model."""
+
 from typing import Any
 
 import torch
@@ -44,7 +46,7 @@ def __init__(
         net: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
         scheduler: torch.optim.lr_scheduler,
-        compile: bool,
+        compile_model: bool,
     ) -> None:
         """Initialize a `MNISTLitModule`.
 
@@ -52,6 +54,7 @@ def __init__(
             net: The model to train.
             optimizer: The optimizer to use for training.
             scheduler: The learning rate scheduler to use for training.
+            compile_model: Whether or not to compile the model.
         """
         super().__init__()
 
@@ -185,8 +188,7 @@ def on_test_epoch_end(self) -> None:
         pass
 
     def setup(self, stage: str) -> None:
-        """Lightning hook that is called at the beginning of fit (train + validate), validate,
-        test, or predict.
+        """Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
 
         This is a good hook when you need to build models dynamically or adjust something
         about them. This hook is called on every process when using DDP.
@@ -194,11 +196,12 @@ def setup(self, stage: str) -> None:
         Args:
             stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
         """
-        if self.hparams.compile and stage == "fit":
+        if self.hparams.compile_model and stage == "fit":
            self.net = torch.compile(self.net)
 
     def configure_optimizers(self) -> dict[str, Any]:
         """Choose what optimizers and learning-rate schedulers to use in your optimization.
+
         Normally you'd need one. But in the case of GANs or similar you might have multiple.
 
         Examples:
diff --git a/src/train.py b/src/train.py
index 65aed31..c6e5ddb 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,12 +1,15 @@
 """Main training script."""
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import hydra
-import lightning as L
+import lightning
 import torch
-from lightning import Callback, LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.loggers import Logger
+
+if TYPE_CHECKING:
+    from lightning import Callback, LightningDataModule, LightningModule, Trainer
+    from lightning.pytorch.loggers import Logger
+
 from omegaconf import DictConfig
 
 from src.utils import (
@@ -37,7 +40,7 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
     """
     # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
-        L.seed_everything(cfg.seed, workers=True)
+        lightning.seed_everything(cfg.seed, workers=True)
 
     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index 8c09795..2c68c15 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -1,7 +1,7 @@
 """This module contains utility functions and classes for the project."""
 
-from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
-from src.utils.logging_utils import log_hyperparameters
-from src.utils.pylogger import RankedLogger
-from src.utils.rich_utils import enforce_tags, print_config_tree
-from src.utils.utils import extras, get_metric_value, task_wrapper
+from src.utils.instantiators import instantiate_callbacks, instantiate_loggers  # noqa
+from src.utils.logging_utils import log_hyperparameters  # noqa
+from src.utils.pylogger import RankedLogger  # noqa
+from src.utils.rich_utils import enforce_tags, print_config_tree  # noqa
+from src.utils.utils import extras, get_metric_value, task_wrapper  # noqa
diff --git a/src/utils/download_utils.py b/src/utils/download_utils.py
index 8ccabae..40aca8b 100644
--- a/src/utils/download_utils.py
+++ b/src/utils/download_utils.py
@@ -1,3 +1,5 @@
+"""Utility functions for downloading data from external sources."""
+
 import cloudpathlib
 
 from src.utils import RankedLogger
diff --git a/src/utils/instantiators.py b/src/utils/instantiators.py
index b28cf09..134b921 100644
--- a/src/utils/instantiators.py
+++ b/src/utils/instantiators.py
@@ -1,3 +1,5 @@
+"""Module to instantiate different object types."""
+
 import hydra
 from lightning import Callback
 from lightning.pytorch.loggers import Logger
@@ -21,7 +23,7 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:
         return callbacks
 
     if not isinstance(callbacks_cfg, DictConfig):
-        raise TypeError("Callbacks config must be a DictConfig!")
+        raise TypeError("Callbacks config must be a DictConfig!")  # noqa: TRY003
 
     for _, cb_conf in callbacks_cfg.items():
         if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
@@ -44,7 +46,7 @@ def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]:
         return logger
 
     if not isinstance(logger_cfg, DictConfig):
-        raise TypeError("Logger config must be a DictConfig!")
+        raise TypeError("Logger config must be a DictConfig!")  # noqa: TRY003
 
     for _, lg_conf in logger_cfg.items():
         if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py
index 9259a68..d57853c 100644
--- a/src/utils/logging_utils.py
+++ b/src/utils/logging_utils.py
@@ -1,3 +1,5 @@
+"""Logging utilities."""
+
 from typing import Any
 
 from lightning_utilities.core.rank_zero import rank_zero_only
diff --git a/src/utils/pylogger.py b/src/utils/pylogger.py
index 712ea22..de7600f 100644
--- a/src/utils/pylogger.py
+++ b/src/utils/pylogger.py
@@ -1,3 +1,5 @@
+"""Multi-GPU-friendly logging."""
+
 import logging
 from collections.abc import Mapping
 
diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py
index 3d72547..94d9942 100644
--- a/src/utils/rich_utils.py
+++ b/src/utils/rich_utils.py
@@ -1,3 +1,5 @@
+"""Rich utilities to print the config tree."""
+
 from collections.abc import Sequence
 from pathlib import Path
 
@@ -85,7 +87,7 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
     """
     if not cfg.get("tags"):
         if "id" in HydraConfig().cfg.hydra.job:
-            raise ValueError("Specify tags before launching a multirun!")
+            raise ValueError("Specify tags before launching a multirun!")  # noqa: TRY003
 
         log.warning("No tags provided in config. Prompting user to input tags...")
         tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 14162f3..04555e6 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -78,14 +78,14 @@ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
             metric_dict, object_dict = task_func(cfg=cfg)
 
         # things to do if exception occurs
-        except Exception as ex:
+        except Exception as e:
             # save exception to `.log` file
             log.exception("")
 
             # some hyperparameter combinations might be invalid or cause out-of-memory errors
             # so when using hparam search plugins like Optuna, you might want to disable
             # raising the below exception to avoid multirun failure
-            raise ex
+            raise e  # noqa: TRY201
 
         # things to always do after either success or exception
         finally:
@@ -120,11 +120,7 @@ def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> No
         return None
 
     if metric_name not in metric_dict:
-        raise Exception(
-            f"Metric value not found! \n"
-            "Make sure metric name logged in LightningModule is correct!\n"
-            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
-        )
+        raise ValueError("Metric value not found!")  # noqa: TRY003
 
     metric_value = metric_dict[metric_name].item()
     log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
diff --git a/tests/conftest.py b/tests/conftest.py
index e69de29..51c165c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -0,0 +1 @@
+"""Fixtures for your unit tests."""