lint: pre-commit linter application
rayanramoul committed Oct 16, 2024
1 parent 82fbb10 commit ee05820
Showing 13 changed files with 45 additions and 28 deletions.
2 changes: 1 addition & 1 deletion configs/model/mnist.yaml
@@ -22,4 +22,4 @@ net:
output_size: 10

# compile model for faster training with pytorch 2.0
-compile: false
+compile_model: false
8 changes: 5 additions & 3 deletions src/eval.py
@@ -1,11 +1,13 @@
"""Main evaluation script."""

-from typing import Any
+from typing import TYPE_CHECKING, Any

import hydra
import torch
-from lightning import LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.loggers import Logger

+if TYPE_CHECKING:
+    from lightning import LightningDataModule, LightningModule, Trainer
+    from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from src.utils import (
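In eval.py the Lightning imports move behind an `if TYPE_CHECKING:` guard, so they are only imported by the type checker rather than at runtime; this is the kind of change Ruff's flake8-type-checking rules suggest for imports used purely in annotations. A minimal sketch of the pattern (the `run_eval` helper below is hypothetical, not part of this commit):

from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only while type checking, so there is no import cost at runtime
    from lightning import LightningModule, Trainer


def run_eval(model: LightningModule, trainer: Trainer) -> None:
    """Hypothetical helper: validate a model with an already-configured trainer."""
    trainer.validate(model=model)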
2 changes: 2 additions & 0 deletions src/models/components/simple_dense_net.py
@@ -1,3 +1,5 @@
"""Simple dense neural network."""

import torch
from torch import nn

11 changes: 7 additions & 4 deletions src/models/mnist_module.py
@@ -1,3 +1,5 @@
"""Mnist simple model."""

from typing import Any

import torch
@@ -44,14 +46,15 @@ def __init__(
net: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler,
-compile: bool,
+compile_model: bool,
) -> None:
"""Initialize a `MNISTLitModule`.
Args:
net: The model to train.
optimizer: The optimizer to use for training.
scheduler: The learning rate scheduler to use for training.
+compile_model: Whether or not compile the model.
"""
super().__init__()

@@ -185,20 +188,20 @@ def on_test_epoch_end(self) -> None:
pass

def setup(self, stage: str) -> None:
"""Lightning hook that is called at the beginning of fit (train + validate), validate,
test, or predict.
"""Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
This is a good hook when you need to build models dynamically or adjust something about
them. This hook is called on every process when using DDP.
Args:
stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
"""
-if self.hparams.compile and stage == "fit":
+if self.hparams.compile_model and stage == "fit":
self.net = torch.compile(self.net)

def configure_optimizers(self) -> dict[str, Any]:
"""Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
Examples:
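The constructor argument `compile` becomes `compile_model`, most likely because `compile` shadows the Python built-in of the same name (something builtin-shadowing lint rules such as flake8-builtins flag), and the renamed flag still gates `torch.compile` inside the `setup` hook. A rough, self-contained sketch of the same gating outside Lightning (the class name and layer sizes below are invented for illustration):

import torch
from torch import nn


class TinyNet(nn.Module):
    """Illustrative two-layer network with an opt-in compile flag."""

    def __init__(self, compile_model: bool = False) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))
        if compile_model:
            # torch.compile (PyTorch >= 2.0) wraps the module in an optimized callable
            self.net = torch.compile(self.net)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = TinyNet(compile_model=False)  # flip to True to opt in to compilation
out = model(torch.randn(8, 784))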
13 changes: 8 additions & 5 deletions src/train.py
@@ -1,12 +1,15 @@
"""Main training script."""

-from typing import Any
+from typing import TYPE_CHECKING, Any

import hydra
-import lightning as L
+import lightning
import torch
-from lightning import Callback, LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.loggers import Logger

+if TYPE_CHECKING:
+    from lightning import Callback, LightningDataModule, LightningModule, Trainer
+    from lightning.pytorch.loggers import Logger

from omegaconf import DictConfig

from src.utils import (
@@ -37,7 +40,7 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
"""
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
-L.seed_everything(cfg.seed, workers=True)
+lightning.seed_everything(cfg.seed, workers=True)

log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
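train.py also replaces the `import lightning as L` alias with the plain module name, probably to satisfy naming-convention checks on uppercase aliases, and calls `lightning.seed_everything` directly. A small sketch of what that call does, assuming only that `lightning` and `torch` are installed:

import lightning
import torch

# seeds python's random, numpy and torch; with workers=True it also
# derives per-worker seeds for dataloader workers
lightning.seed_everything(42, workers=True)
a = torch.rand(3)

lightning.seed_everything(42, workers=True)
b = torch.rand(3)

assert torch.equal(a, b)  # identical draws after re-seeding with the same value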
10 changes: 5 additions & 5 deletions src/utils/__init__.py
@@ -1,7 +1,7 @@
"""This module contains utility functions and classes for the project."""

-from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
-from src.utils.logging_utils import log_hyperparameters
-from src.utils.pylogger import RankedLogger
-from src.utils.rich_utils import enforce_tags, print_config_tree
-from src.utils.utils import extras, get_metric_value, task_wrapper
+from src.utils.instantiators import instantiate_callbacks, instantiate_loggers  # noqa
+from src.utils.logging_utils import log_hyperparameters  # noqa
+from src.utils.pylogger import RankedLogger  # noqa
+from src.utils.rich_utils import enforce_tags, print_config_tree  # noqa
+from src.utils.utils import extras, get_metric_value, task_wrapper  # noqa
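In `src/utils/__init__.py` the `# noqa` markers keep the unused-import rule (Ruff/flake8 F401) quiet: the imports exist only so other modules can write `from src.utils import task_wrapper`. One common alternative is to declare the re-exports explicitly via `__all__`; a hedged sketch with a trimmed-down set of names:

"""This module contains utility functions and classes for the project."""

from src.utils.pylogger import RankedLogger
from src.utils.utils import extras, get_metric_value, task_wrapper

# declaring __all__ marks the imports as intentional re-exports,
# so linters no longer report them as unused
__all__ = ["RankedLogger", "extras", "get_metric_value", "task_wrapper"]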
2 changes: 2 additions & 0 deletions src/utils/download_utils.py
@@ -1,3 +1,5 @@
"""Utility functions aimed at downloading any data from external sources."""

import cloudpathlib

from src.utils import RankedLogger
6 changes: 4 additions & 2 deletions src/utils/instantiators.py
@@ -1,3 +1,5 @@
"""Module to instantiate different objects types."""

import hydra
from lightning import Callback
from lightning.pytorch.loggers import Logger
@@ -21,7 +23,7 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:
return callbacks

if not isinstance(callbacks_cfg, DictConfig):
raise TypeError("Callbacks config must be a DictConfig!")
raise TypeError("Callbacks config must be a DictConfig!") # noqa: TRY003

for _, cb_conf in callbacks_cfg.items():
if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
@@ -44,7 +46,7 @@ def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]:
return logger

if not isinstance(logger_cfg, DictConfig):
raise TypeError("Logger config must be a DictConfig!")
raise TypeError("Logger config must be a DictConfig!") # noqa: TRY003

for _, lg_conf in logger_cfg.items():
if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
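The `# noqa: TRY003` comments silence Ruff's tryceratops rule against passing long messages straight to an exception constructor; the rule's preferred shape is a dedicated exception type that owns its message. A sketch of that alternative (the `NotADictConfigError` class is invented for illustration, not part of this commit):

from omegaconf import DictConfig


class NotADictConfigError(TypeError):
    """Raised when a config node is not a DictConfig."""

    def __init__(self, name: str) -> None:
        super().__init__(f"{name} config must be a DictConfig!")


def check_callbacks_cfg(callbacks_cfg: object) -> None:
    # same check as instantiate_callbacks, but the message lives in the exception class
    if not isinstance(callbacks_cfg, DictConfig):
        raise NotADictConfigError("Callbacks")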
2 changes: 2 additions & 0 deletions src/utils/logging_utils.py
@@ -1,3 +1,5 @@
"""Logging utility instantiator."""

from typing import Any

from lightning_utilities.core.rank_zero import rank_zero_only
2 changes: 2 additions & 0 deletions src/utils/pylogger.py
@@ -1,3 +1,5 @@
"""Code for logging on multi-GPU-friendly."""

import logging
from collections.abc import Mapping

4 changes: 3 additions & 1 deletion src/utils/rich_utils.py
@@ -1,3 +1,5 @@
"""Rich utils to print config tree."""

from collections.abc import Sequence
from pathlib import Path

@@ -85,7 +87,7 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
"""
if not cfg.get("tags"):
if "id" in HydraConfig().cfg.hydra.job:
raise ValueError("Specify tags before launching a multirun!")
raise ValueError("Specify tags before launching a multirun!") # noqa

log.warning("No tags provided in config. Prompting user to input tags...")
tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
10 changes: 3 additions & 7 deletions src/utils/utils.py
@@ -78,14 +78,14 @@ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
metric_dict, object_dict = task_func(cfg=cfg)

# things to do if exception occurs
-except Exception as ex:
+except Exception as e:
# save exception to `.log` file
log.exception("")

# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
-raise ex
+raise e  # noqa: TRY201

# things to always do after either success or exception
finally:
@@ -120,11 +120,7 @@ def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> No
return None

if metric_name not in metric_dict:
-raise Exception(
-    f"Metric value not found! <metric_name={metric_name}>\n"
-    "Make sure metric name logged in LightningModule is correct!\n"
-    "Make sure `optimized_metric` name in `hparams_search` config is correct!"
-)
+raise ValueError(f"Metric value not found! <metric_name={metric_name}>\n")  # noqa: TRY003

metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
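In `task_wrapper`, the exception variable is renamed from `ex` to `e` and the explicit `raise e` is kept with `# noqa: TRY201`; TRY201 prefers a bare `raise` inside an `except` block, since a bare `raise` re-raises the active exception without touching its traceback. A minimal sketch of the two spellings, independent of this repository's code:

def flaky() -> None:
    raise RuntimeError("boom")


def run_guarded() -> None:
    try:
        flaky()
    except Exception as e:
        print(f"logging failure: {e}")
        # a bare `raise` is what TRY201 asks for; `raise e` also works
        # but adds this line to the reported traceback
        raise e  # noqa: TRY201


try:
    run_guarded()
except RuntimeError:
    pass  # demo only: the error was already logged above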
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -0,0 +1 @@
"""Fixtures for your unit tests."""
