diff --git a/index.html b/index.html index 2d06a82..d58dbe0 100644 --- a/index.html +++ b/index.html @@ -419,18 +419,27 @@ - - +
  • + + + 📚 Documentation + +
  • - -
  • + +
  • 🌳 Tree Explained +
  • + + + +
  • @@ -1205,7 +1214,16 @@

    💡 Best practicesUV is powerful (multi-thread, package graph solving, rust backend, etc.) use it as much as you can.

  • -

    🌳 Tree Explained#

    +

    📚 Documentation#

    +

    You have the possibility to generate a documentation website using Mkdocs. It will automatically generate the documentation based on both the markdown files in the docs/ folder and the docstrings in your code. +To generate and serve the documentation locally:

    +
    1
    make serve-docs # Documentation will be available at http://localhost:8000
    +
    +

    And to deploy it to Github pages (youn need to enable Pages in your repository configuration and set it to use the +gh-pages branch):

    +
    1
    make pages-deploy # It will create a gh-pages branch and push the documentation to it
    +
    +

    🌳 Tree Explained#

    ``` . ├── commit-template.txt # use this file to set your commit message template, with make configure-commit template diff --git a/search/search_index.json b/search/search_index.json index af3ee5f..ee82498 100644 --- a/search/search_index.json +++ b/search/search_index.json @@ -1 +1 @@ -{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Home","text":"Machine Learning Project Template [![python](https://img.shields.io/badge/-Python_3.8_%7C_3.9_%7C_3.10-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit) [![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) [![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) Click on [Use this template](https://github.com/rayanramoul/ml-project-template/generate) to start your own project! or go to the [Documentation](https://rayanramoul.github.io/ml-project-template/) for more information. A template for machine learning or deep learning projects."},{"location":"#features","title":"\ud83e\udde0 Features","text":"

    "},{"location":"#steps-for-installation","title":"\u2699\ufe0f Steps for Installation","text":""},{"location":"#tips-and-tricks","title":"\ud83e\udd20Tips and Tricks","text":""},{"location":"#how-does-the-project-work","title":"\ud83d\udc0d How does the project work?","text":"

    The train.py or eval.py script is the entry point of the project. It uses Hydra to instantiate the model (LightningModule), dataloader (DataModule), and trainer using the configuration reconstructed using Hydra. The model is then trained or evaluated using Pytorch Lightning.

    "},{"location":"#implementing-your-logic","title":"Implementing your logic","text":"

    You don't need to worry about implementing the training loops, the support for different hardwares, reading of configurations, etc. You need to care about 4 files for each training : your LightningModule (+ its hydra config), your DataModule (+ its hydra config).

    In the LightningModule, you need to implement the following methods:

    Get inspired by the provided examples in the src/data folder.

    Get to know more about Pytorch Lightning's LightningModule and DataModule in the Pytorch Lightning documentation. Finally in the associated configs/ folder, you need to implement the yaml configuration files for the model and dataloader.

    "},{"location":"#the-power-of-hydra","title":"\ud83d\udd0d The power of Hydra","text":"

    As Hydra is used for configuration, you can easily change the hyperparameters of your model, the dataloader, the trainer, etc. by changing the yaml configuration files in the configs/ folder. You can also use the --multirun option to run multiple experiments with different configurations.

    But also, as it used to instantiate the model and dataloader, you can easily change the model, dataloader, or any other component by changing the yaml configuration files or DIRECTLY IN COMMAND LINE. This is especially useful when you want to use different models or dataloaders.

    For example, you can run the following command to train a model with a different architecture, changing the dataset used, and the trainer used:

    uv run src/train.py model=LeNet datamodule=MNISTDataModule trainer=gpu\n

    Read more about Hydra in the official documentation.

    "},{"location":"#best-practices","title":"\ud83d\udca1 Best practices","text":""},{"location":"#tree-explained","title":"\ud83c\udf33 Tree Explained","text":"

    ``` . \u251c\u2500\u2500 commit-template.txt # use this file to set your commit message template, with make configure-commit template \u251c\u2500\u2500 configs # configuration files for hydra \u2502\u00a0\u00a0 \u251c\u2500\u2500 callbacks # configuration files for callbacks \u2502\u00a0\u00a0 \u251c\u2500\u2500 data # configuration files for datamodules \u2502\u00a0\u00a0 \u251c\u2500\u2500 debug # configuration files for pytorch lightning debuggers \u2502\u00a0\u00a0 \u251c\u2500\u2500 eval.yaml # configuration file for evaluation \u2502\u00a0\u00a0 \u251c\u2500\u2500 experiment # configuration files for experiments \u2502\u00a0\u00a0 \u251c\u2500\u2500 extras # configuration files for extra components \u2502\u00a0\u00a0 \u251c\u2500\u2500 hparams_search # configuration files for hyperparameters search \u2502\u00a0\u00a0 \u251c\u2500\u2500 local # configuration files for local training \u2502\u00a0\u00a0 \u251c\u2500\u2500 logger # configuration files for loggers (neptune, wandb, etc.) \u2502\u00a0\u00a0 \u251c\u2500\u2500 model # configuration files for models (LightningModule) \u2502\u00a0\u00a0 \u251c\u2500\u2500 paths # configuration files for paths \u2502\u00a0\u00a0 \u251c\u2500\u2500 trainer # configuration files for trainers (cpu, gpu, tpu) \u2502\u00a0\u00a0 \u2514\u2500\u2500 train.yaml # configuration file for training \u251c\u2500\u2500 data # data folder (to store potentially downloaded datasets) \u251c\u2500\u2500 Makefile # makefile contains useful commands for the project \u251c\u2500\u2500 notebooks # notebooks folder \u251c\u2500\u2500 pyproject.toml # pyproject.toml file for uv package manager \u251c\u2500\u2500 README.md # this file \u251c\u2500\u2500 ruff.toml # ruff.toml file for pre-commit \u251c\u2500\u2500 scripts # scripts folder \u2502\u00a0\u00a0 \u2514\u2500\u2500 example_train.sh \u251c\u2500\u2500 src # source code folder \u2502\u00a0\u00a0 \u251c\u2500\u2500 data # datamodules folder \u2502\u00a0\u00a0 \u2502\u00a0\u00a0 \u251c\u2500\u2500 components \u2502\u00a0\u00a0 \u2502\u00a0\u00a0 \u2514\u2500\u2500 mnist_datamodule.py \u2502\u00a0\u00a0 \u251c\u2500\u2500 eval.py # evaluation entry script \u2502\u00a0\u00a0 \u251c\u2500\u2500 models # models folder (LightningModule) \u2502\u00a0\u00a0 \u2502\u00a0\u00a0 \u251c\u2500\u2500 components # components folder, contains model parts or \"nets\" \u2502\u00a0\u00a0 \u251c\u2500\u2500 train.py # training entry script \u2502\u00a0\u00a0 \u2514\u2500\u2500 utils # utils folder \u2502\u00a0\u00a0 \u251c\u2500\u2500 instantiators.py # instantiators for models and dataloaders \u2502\u00a0\u00a0 \u251c\u2500\u2500 logging_utils.py # logger utils \u2502\u00a0\u00a0 \u251c\u2500\u2500 pylogger.py # multi-process and multi-gpu safe logging \u2502\u00a0\u00a0 \u251c\u2500\u2500 rich_utils.py # rich utils \u2502\u00a0\u00a0 \u2514\u2500\u2500 utils.py # general utils like multi-processing, etc. \u2514\u2500\u2500 tests # tests folder \u2514\u2500\u2500 conftest.py # fixtures for tests \u2514\u2500\u2500 mock_test.py # example of mocking tests

    ````

    "},{"location":"#contributing","title":"\ud83e\udd1d Contributing","text":"

    For more information on how to contribute to this project, please refer to the CONTRIBUTING.md file.

    "},{"location":"#aknowledgements","title":"\ud83c\udf1f Aknowledgements","text":"

    This template was heavily inspired by great existing ones, like:

    But with a few opininated changes and improvements, go check them out!

    "},{"location":"api/eval/","title":"Eval","text":"

    Main evaluation script.

    "},{"location":"api/eval/#src.eval.evaluate","title":"evaluate(cfg)","text":"

    Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during failure. Useful for multiruns, saving info about the crash, etc.

    Parameters:

    Name Type Description Default cfg DictConfig

    DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description tuple[dict[str, Any], dict[str, Any]]

    tuple[dict, dict] with metrics and dict with all instantiated objects.

    Source code in src/eval.py
    @task_wrapper\ndef evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:\n    \"\"\"Evaluates given checkpoint on a datamodule testset.\n\n    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during\n    failure. Useful for multiruns, saving info about the crash, etc.\n\n    Args:\n        cfg: DictConfig configuration composed by Hydra.\n\n    Returns:\n        tuple[dict, dict] with metrics and dict with all instantiated objects.\n    \"\"\"\n    assert cfg.ckpt_path\n\n    log.info(f\"Instantiating datamodule <{cfg.data._target_}>\")\n    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)\n\n    log.info(f\"Instantiating model <{cfg.model._target_}>\")\n    model: LightningModule = hydra.utils.instantiate(cfg.model)\n\n    if cfg.get(\"model_compile\", False):\n        log.info(\"Compiling model...\")\n        torch.compile(model)\n\n    log.info(\"Instantiating loggers...\")\n    logger: list[Logger] = instantiate_loggers(cfg.get(\"logger\"))\n\n    log.info(f\"Instantiating trainer <{cfg.trainer._target_}>\")\n    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)\n\n    object_dict = {\n        \"cfg\": cfg,\n        \"datamodule\": datamodule,\n        \"model\": model,\n        \"logger\": logger,\n        \"trainer\": trainer,\n    }\n\n    if logger:\n        log.info(\"Logging hyperparameters!\")\n        log_hyperparameters(object_dict)\n\n    log.info(\"Starting testing!\")\n    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)\n\n    # for predictions use trainer.predict(...)\n    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)\n\n    metric_dict = trainer.callback_metrics\n\n    return metric_dict, object_dict\n
    "},{"location":"api/eval/#src.eval.main","title":"main(cfg)","text":"

    Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.

    Source code in src/eval.py
    @hydra.main(version_base=\"1.3\", config_path=\"../configs\", config_name=\"eval.yaml\")\ndef main(cfg: DictConfig) -> None:\n    \"\"\"Main entry point for evaluation.\n\n    :param cfg: DictConfig configuration composed by Hydra.\n    \"\"\"\n    # apply extra utilities\n    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)\n    extras(cfg)\n\n    evaluate(cfg)\n
    "},{"location":"api/serve/","title":"Serve","text":"

    Main serve script.

    "},{"location":"api/serve/#src.serve.main","title":"main(cfg)","text":"

    Main entry point for serving.

    Parameters:

    Name Type Description Default cfg DictConfig

    DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description None

    Optional[float] with optimized metric value.

    Source code in src/serve.py
    @hydra.main(version_base=\"1.3\", config_path=\"../configs\", config_name=\"serve.yaml\")\ndef main(cfg: DictConfig) -> None:\n    \"\"\"Main entry point for serving.\n\n    Args:\n        cfg: DictConfig configuration composed by Hydra.\n\n    Returns:\n        Optional[float] with optimized metric value.\n    \"\"\"\n    serve(cfg)\n
    "},{"location":"api/serve/#src.serve.serve","title":"serve(cfg)","text":"

    Serve the specified model in the configuration as a FastAPI api.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description None

    A tuple with metrics and dict with all instantiated objects.

    Source code in src/serve.py
    @task_wrapper\ndef serve(cfg: DictConfig) -> None:\n    \"\"\"Serve the specified model in the configuration as a FastAPI api.\n\n    Args:\n        cfg: A DictConfig configuration composed by Hydra.\n\n    Returns:\n        A tuple with metrics and dict with all instantiated objects.\n    \"\"\"\n    # set seed for random number generators in pytorch, numpy and python.random\n    if cfg.get(\"seed\"):\n        lightning.seed_everything(cfg.seed, workers=True)\n    log.info(f\"Getting model class <{cfg.model._target_}>\")\n    model_class = hydra.utils.get_class(cfg.model._target_)\n    lit_server_api = hydra.utils.instantiate(cfg.serve.api, model_class=model_class)\n    # Create the LitServe server with the MNISTServeAPI\n    server = ls.LitServer(lit_server_api, accelerator=cfg.serve.accelerator, max_batch_size=cfg.serve.max_batch_size)\n    log.info(\"Initialized LitServe server\")\n    # Run the server on port 8000\n    log.info(f\"Starting LitServe server on port {cfg.serve.port}\")\n    server.run(port=cfg.serve.port)\n
    "},{"location":"api/train/","title":"Train","text":"

    Main training script.

    "},{"location":"api/train/#src.train.main","title":"main(cfg)","text":"

    Main entry point for training.

    Parameters:

    Name Type Description Default cfg DictConfig

    DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description float | None

    Optional[float] with optimized metric value.

    Source code in src/train.py
    @hydra.main(version_base=\"1.3\", config_path=\"../configs\", config_name=\"train.yaml\")\ndef main(cfg: DictConfig) -> float | None:\n    \"\"\"Main entry point for training.\n\n    Args:\n        cfg: DictConfig configuration composed by Hydra.\n\n    Returns:\n        Optional[float] with optimized metric value.\n    \"\"\"\n    # apply extra utilities\n    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)\n    extras(cfg)\n\n    # train the model\n    metric_dict, _ = train(cfg)\n\n    # safely retrieve metric value for hydra-based hyperparameter optimization\n    metric_value = get_metric_value(metric_dict=metric_dict, metric_name=cfg.get(\"optimized_metric\"))\n\n    # return optimized metric\n    return metric_value\n
    "},{"location":"api/train/#src.train.train","title":"train(cfg)","text":"

    Trains the model. Can additionally evaluate on a testset, using best weights obtained during training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during failure. Useful for multiruns, saving info about the crash, etc.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description tuple[dict[str, Any], dict[str, Any]]

    A tuple with metrics and dict with all instantiated objects.

    Source code in src/train.py
    @task_wrapper\ndef train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:\n    \"\"\"Trains the model. Can additionally evaluate on a testset, using best weights obtained during training.\n\n    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during\n    failure. Useful for multiruns, saving info about the crash, etc.\n\n    Args:\n        cfg: A DictConfig configuration composed by Hydra.\n\n    Returns:\n        A tuple with metrics and dict with all instantiated objects.\n    \"\"\"\n    # set seed for random number generators in pytorch, numpy and python.random\n    if cfg.get(\"seed\"):\n        lightning.seed_everything(cfg.seed, workers=True)\n\n    log.info(f\"Instantiating datamodule <{cfg.data._target_}>\")\n    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)\n\n    log.info(f\"Instantiating model <{cfg.model._target_}>\")\n    model: LightningModule = hydra.utils.instantiate(cfg.model)\n\n    if cfg.get(\"model_compile\", False):\n        log.info(\"Compiling model...\")\n        torch.compile(model)\n\n    log.info(\"Instantiating callbacks...\")\n    callbacks: list[Callback] = instantiate_callbacks(cfg.get(\"callbacks\"))\n\n    log.info(\"Instantiating loggers...\")\n    logger: list[Logger] = instantiate_loggers(cfg.get(\"logger\"))\n\n    log.info(f\"Instantiating trainer <{cfg.trainer._target_}>\")\n    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)\n\n    object_dict = {\n        \"cfg\": cfg,\n        \"datamodule\": datamodule,\n        \"model\": model,\n        \"callbacks\": callbacks,\n        \"logger\": logger,\n        \"trainer\": trainer,\n    }\n\n    if logger:\n        log.info(\"Logging hyperparameters!\")\n        log_hyperparameters(object_dict)\n\n    if cfg.get(\"train\"):\n        log.info(\"Starting training!\")\n        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get(\"ckpt_path\"))\n\n    train_metrics = trainer.callback_metrics\n\n    if cfg.get(\"test\"):\n        log.info(\"Starting testing!\")\n        ckpt_path = trainer.checkpoint_callback.best_model_path\n        if ckpt_path == \"\":\n            log.warning(\"Best ckpt not found! Using current weights for testing...\")\n            ckpt_path = None\n        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)\n        log.info(f\"Best ckpt path: {ckpt_path}\")\n\n    test_metrics = trainer.callback_metrics\n\n    # merge train and test metrics\n    metric_dict = {**train_metrics, **test_metrics}\n\n    return metric_dict, object_dict\n
    "},{"location":"api/data/mnist_datamodule/","title":"Mnist datamodule","text":"

    MNIST DataModule.

    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule","title":"MNISTDataModule","text":"

    Bases: LightningDataModule

    LightningDataModule for the MNIST dataset.

    The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.

    A LightningDataModule implements 7 key methods:

        def prepare_data(self):\n    # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).\n    # Download data, pre-process, split, save to disk, etc...\n\n    def setup(self, stage):\n    # Things to do on every process in DDP.\n    # Load data, set variables, etc...\n\n    def train_dataloader(self):\n    # return train dataloader\n\n    def val_dataloader(self):\n    # return validation dataloader\n\n    def test_dataloader(self):\n    # return test dataloader\n\n    def predict_dataloader(self):\n    # return predict dataloader\n\n    def teardown(self, stage):\n    # Called on every process in DDP.\n    # Clean up after fit or test.\n

    This allows you to share a full dataset without explaining how to download, split, transform and process the data.

    Read the docs

    https://lightning.ai/docs/pytorch/latest/data/datamodule.html

    Source code in src/data/mnist_datamodule.py
    class MNISTDataModule(LightningDataModule):\n    \"\"\"`LightningDataModule` for the MNIST dataset.\n\n    The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.\n    It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a\n    fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box\n    while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing\n    technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of\n    mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.\n\n    A `LightningDataModule` implements 7 key methods:\n\n    ```python\n        def prepare_data(self):\n        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).\n        # Download data, pre-process, split, save to disk, etc...\n\n        def setup(self, stage):\n        # Things to do on every process in DDP.\n        # Load data, set variables, etc...\n\n        def train_dataloader(self):\n        # return train dataloader\n\n        def val_dataloader(self):\n        # return validation dataloader\n\n        def test_dataloader(self):\n        # return test dataloader\n\n        def predict_dataloader(self):\n        # return predict dataloader\n\n        def teardown(self, stage):\n        # Called on every process in DDP.\n        # Clean up after fit or test.\n    ```\n\n    This allows you to share a full dataset without explaining how to download,\n    split, transform and process the data.\n\n    Read the docs:\n        https://lightning.ai/docs/pytorch/latest/data/datamodule.html\n    \"\"\"\n\n    def __init__(\n        self,\n        data_dir: str = \"data/\",\n        train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000),\n        batch_size: int = 64,\n        num_workers: int = 0,\n        pin_memory: bool = False,\n    ) -> None:\n        \"\"\"Initialize a `MNISTDataModule`.\n\n        Args:\n            data_dir: The data directory. Defaults to `\"data/\"`.\n            train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.\n            batch_size: The batch size. Defaults to `64`.\n            num_workers: The number of workers. Defaults to `0`.\n            pin_memory: Whether to pin memory. Defaults to `False`.\n        \"\"\"\n        super().__init__()\n\n        # this line allows to access init params with 'self.hparams' attribute\n        # also ensures init params will be stored in ckpt\n        self.save_hyperparameters(logger=False)\n\n        # data transformations\n        self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n\n        self.data_train: Dataset | None = None\n        self.data_val: Dataset | None = None\n        self.data_test: Dataset | None = None\n\n        self.batch_size_per_device = batch_size\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"Get the number of classes.\n\n        :return: The number of MNIST classes (10).\n        \"\"\"\n        return 10\n\n    def prepare_data(self) -> None:\n        \"\"\"Download data if needed.\n\n        Lightning ensures that `self.prepare_data()` is called only\n        within a single process on CPU, so you can safely add your downloading logic within. In\n        case of multi-node training, the execution of this hook depends upon\n        `self.prepare_data_per_node()`.\n\n        Do not use it to assign state (self.x = y).\n        \"\"\"\n        MNIST(self.hparams.data_dir, train=True, download=True)\n        MNIST(self.hparams.data_dir, train=False, download=True)\n\n    def setup(self, stage: str | None = None) -> None:\n        \"\"\"Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.\n\n        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and\n        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after\n        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to\n        `self.setup()` once the data is prepared and available for use.\n\n        Args:\n            stage: The stage to setup. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`. Defaults to ``None``.\n        \"\"\"\n        # Divide batch size by the number of devices.\n        if self.trainer is not None:\n            if self.hparams.batch_size % self.trainer.world_size != 0:\n                raise RuntimeError(  # noqa\n                    f\"Batch size ({self.hparams.batch_size}) \"\n                    \"is not divisible by the number of devices ({self.trainer.world_size}).\"\n                )\n            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size\n\n        # load and split datasets only if not loaded already\n        if not self.data_train and not self.data_val and not self.data_test:\n            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)\n            testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)\n            dataset = ConcatDataset(datasets=[trainset, testset])\n            self.data_train, self.data_val, self.data_test = random_split(\n                dataset=dataset,\n                lengths=self.hparams.train_val_test_split,\n                generator=torch.Generator().manual_seed(42),\n            )\n\n    def train_dataloader(self) -> DataLoader[Any]:\n        \"\"\"Create and return the train dataloader.\n\n        Returns:\n            The train dataloader.\n        \"\"\"\n        return DataLoader(\n            dataset=self.data_train,\n            batch_size=self.batch_size_per_device,\n            num_workers=self.hparams.num_workers,\n            pin_memory=self.hparams.pin_memory,\n            shuffle=True,\n        )\n\n    def val_dataloader(self) -> DataLoader[Any]:\n        \"\"\"Create and return the validation dataloader.\n\n        Returns:\n            The validation dataloader.\n        \"\"\"\n        return DataLoader(\n            dataset=self.data_val,\n            batch_size=self.batch_size_per_device,\n            num_workers=self.hparams.num_workers,\n            pin_memory=self.hparams.pin_memory,\n            shuffle=False,\n        )\n\n    def test_dataloader(self) -> DataLoader[Any]:\n        \"\"\"Create and return the test dataloader.\n\n        Returns:\n            The test dataloader.\n        \"\"\"\n        return DataLoader(\n            dataset=self.data_test,\n            batch_size=self.batch_size_per_device,\n            num_workers=self.hparams.num_workers,\n            pin_memory=self.hparams.pin_memory,\n            shuffle=False,\n        )\n\n    def teardown(self, stage: str | None = None) -> None:\n        \"\"\"Lightning hook for cleaning up after trainer main functions.\n\n        `trainer.fit()`, `trainer.validate()`,`trainer.test()`, and `trainer.predict()`.\n\n        Args:\n            stage: The stage being torn down. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n            Defaults to ``None``.\n        \"\"\"\n        pass\n\n    def state_dict(self) -> dict[Any, Any]:\n        \"\"\"Called when saving a checkpoint. Implement to generate and save the datamodule state.\n\n        Returns:\n            A dictionary containing the datamodule state that you want to save.\n        \"\"\"\n        return {}\n\n    def load_state_dict(self, state_dict: dict[str, Any]) -> None:\n        \"\"\"Called when loading a checkpoint. Implement to reload datamodule state given datamodule `state_dict()`.\n\n        Args:\n            state_dict: The datamodule state returned by `self.state_dict()`.\n        \"\"\"\n        pass\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.num_classes","title":"num_classes: int property","text":"

    Get the number of classes.

    :return: The number of MNIST classes (10).

    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.__init__","title":"__init__(data_dir='data/', train_val_test_split=(55000, 5000, 10000), batch_size=64, num_workers=0, pin_memory=False)","text":"

    Initialize a MNISTDataModule.

    Parameters:

    Name Type Description Default data_dir str

    The data directory. Defaults to \"data/\".

    'data/' train_val_test_split tuple[int, int, int]

    The train, validation and test split. Defaults to (55_000, 5_000, 10_000).

    (55000, 5000, 10000) batch_size int

    The batch size. Defaults to 64.

    64 num_workers int

    The number of workers. Defaults to 0.

    0 pin_memory bool

    Whether to pin memory. Defaults to False.

    False Source code in src/data/mnist_datamodule.py
    def __init__(\n    self,\n    data_dir: str = \"data/\",\n    train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000),\n    batch_size: int = 64,\n    num_workers: int = 0,\n    pin_memory: bool = False,\n) -> None:\n    \"\"\"Initialize a `MNISTDataModule`.\n\n    Args:\n        data_dir: The data directory. Defaults to `\"data/\"`.\n        train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.\n        batch_size: The batch size. Defaults to `64`.\n        num_workers: The number of workers. Defaults to `0`.\n        pin_memory: Whether to pin memory. Defaults to `False`.\n    \"\"\"\n    super().__init__()\n\n    # this line allows to access init params with 'self.hparams' attribute\n    # also ensures init params will be stored in ckpt\n    self.save_hyperparameters(logger=False)\n\n    # data transformations\n    self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n\n    self.data_train: Dataset | None = None\n    self.data_val: Dataset | None = None\n    self.data_test: Dataset | None = None\n\n    self.batch_size_per_device = batch_size\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.load_state_dict","title":"load_state_dict(state_dict)","text":"

    Called when loading a checkpoint. Implement to reload datamodule state given datamodule state_dict().

    Parameters:

    Name Type Description Default state_dict dict[str, Any]

    The datamodule state returned by self.state_dict().

    required Source code in src/data/mnist_datamodule.py
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:\n    \"\"\"Called when loading a checkpoint. Implement to reload datamodule state given datamodule `state_dict()`.\n\n    Args:\n        state_dict: The datamodule state returned by `self.state_dict()`.\n    \"\"\"\n    pass\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.prepare_data","title":"prepare_data()","text":"

    Download data if needed.

    Lightning ensures that self.prepare_data() is called only within a single process on CPU, so you can safely add your downloading logic within. In case of multi-node training, the execution of this hook depends upon self.prepare_data_per_node().

    Do not use it to assign state (self.x = y).

    Source code in src/data/mnist_datamodule.py
    def prepare_data(self) -> None:\n    \"\"\"Download data if needed.\n\n    Lightning ensures that `self.prepare_data()` is called only\n    within a single process on CPU, so you can safely add your downloading logic within. In\n    case of multi-node training, the execution of this hook depends upon\n    `self.prepare_data_per_node()`.\n\n    Do not use it to assign state (self.x = y).\n    \"\"\"\n    MNIST(self.hparams.data_dir, train=True, download=True)\n    MNIST(self.hparams.data_dir, train=False, download=True)\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.setup","title":"setup(stage=None)","text":"

    Load data. Set variables: self.data_train, self.data_val, self.data_test.

    This method is called by Lightning before trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict(), so be careful not to execute things like random split twice! Also, it is called after self.prepare_data() and there is a barrier in between which ensures that all the processes proceed to self.setup() once the data is prepared and available for use.

    Parameters:

    Name Type Description Default stage str | None

    The stage to setup. Either \"fit\", \"validate\", \"test\", or \"predict\". Defaults to None.

    None Source code in src/data/mnist_datamodule.py
    def setup(self, stage: str | None = None) -> None:\n    \"\"\"Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.\n\n    This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and\n    `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after\n    `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to\n    `self.setup()` once the data is prepared and available for use.\n\n    Args:\n        stage: The stage to setup. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`. Defaults to ``None``.\n    \"\"\"\n    # Divide batch size by the number of devices.\n    if self.trainer is not None:\n        if self.hparams.batch_size % self.trainer.world_size != 0:\n            raise RuntimeError(  # noqa\n                f\"Batch size ({self.hparams.batch_size}) \"\n                \"is not divisible by the number of devices ({self.trainer.world_size}).\"\n            )\n        self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size\n\n    # load and split datasets only if not loaded already\n    if not self.data_train and not self.data_val and not self.data_test:\n        trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)\n        testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)\n        dataset = ConcatDataset(datasets=[trainset, testset])\n        self.data_train, self.data_val, self.data_test = random_split(\n            dataset=dataset,\n            lengths=self.hparams.train_val_test_split,\n            generator=torch.Generator().manual_seed(42),\n        )\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.state_dict","title":"state_dict()","text":"

    Called when saving a checkpoint. Implement to generate and save the datamodule state.

    Returns:

    Type Description dict[Any, Any]

    A dictionary containing the datamodule state that you want to save.

    Source code in src/data/mnist_datamodule.py
    def state_dict(self) -> dict[Any, Any]:\n    \"\"\"Called when saving a checkpoint. Implement to generate and save the datamodule state.\n\n    Returns:\n        A dictionary containing the datamodule state that you want to save.\n    \"\"\"\n    return {}\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.teardown","title":"teardown(stage=None)","text":"

    Lightning hook for cleaning up after trainer main functions.

    trainer.fit(), trainer.validate(),trainer.test(), and trainer.predict().

    Parameters:

    Name Type Description Default stage str | None

    The stage being torn down. Either \"fit\", \"validate\", \"test\", or \"predict\".

    None Source code in src/data/mnist_datamodule.py
    def teardown(self, stage: str | None = None) -> None:\n    \"\"\"Lightning hook for cleaning up after trainer main functions.\n\n    `trainer.fit()`, `trainer.validate()`,`trainer.test()`, and `trainer.predict()`.\n\n    Args:\n        stage: The stage being torn down. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n        Defaults to ``None``.\n    \"\"\"\n    pass\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.test_dataloader","title":"test_dataloader()","text":"

    Create and return the test dataloader.

    Returns:

    Type Description DataLoader[Any]

    The test dataloader.

    Source code in src/data/mnist_datamodule.py
    def test_dataloader(self) -> DataLoader[Any]:\n    \"\"\"Create and return the test dataloader.\n\n    Returns:\n        The test dataloader.\n    \"\"\"\n    return DataLoader(\n        dataset=self.data_test,\n        batch_size=self.batch_size_per_device,\n        num_workers=self.hparams.num_workers,\n        pin_memory=self.hparams.pin_memory,\n        shuffle=False,\n    )\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.train_dataloader","title":"train_dataloader()","text":"

    Create and return the train dataloader.

    Returns:

    Type Description DataLoader[Any]

    The train dataloader.

    Source code in src/data/mnist_datamodule.py
    def train_dataloader(self) -> DataLoader[Any]:\n    \"\"\"Create and return the train dataloader.\n\n    Returns:\n        The train dataloader.\n    \"\"\"\n    return DataLoader(\n        dataset=self.data_train,\n        batch_size=self.batch_size_per_device,\n        num_workers=self.hparams.num_workers,\n        pin_memory=self.hparams.pin_memory,\n        shuffle=True,\n    )\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.val_dataloader","title":"val_dataloader()","text":"

    Create and return the validation dataloader.

    Returns:

    Type Description DataLoader[Any]

    The validation dataloader.

    Source code in src/data/mnist_datamodule.py
    def val_dataloader(self) -> DataLoader[Any]:\n    \"\"\"Create and return the validation dataloader.\n\n    Returns:\n        The validation dataloader.\n    \"\"\"\n    return DataLoader(\n        dataset=self.data_val,\n        batch_size=self.batch_size_per_device,\n        num_workers=self.hparams.num_workers,\n        pin_memory=self.hparams.pin_memory,\n        shuffle=False,\n    )\n
    "},{"location":"api/data/polars_datamodule/","title":"Polars datamodule","text":"

    PyTorch Lightning DataModule for loading dataset using Polars.

    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule","title":"PolarsDataModule","text":"

    Bases: LightningDataModule

    PyTorch Lightning DataModule for loading dataset using Polars.

    Source code in src/data/polars_datamodule.py
    class PolarsDataModule(LightningDataModule):\n    \"\"\"PyTorch Lightning DataModule for loading dataset using Polars.\"\"\"\n\n    def __init__(\n        self, data_path: str, output_column: str, batch_size: int = 32, num_workers: int = 0, test_size: float = 0.2\n    ) -> None:\n        \"\"\"Initialize the PolarsDataModule.\n\n        Args:\n            data_path: Path to the dataset.\n            output_column: Column name that contains the labels.\n            batch_size: Batch size for the dataloaders.\n            num_workers: Number of workers for the dataloaders.\n            test_size: Fraction of the dataset to be used for validation.\n        \"\"\"\n        super().__init__()\n        self.data_path = data_path\n        self.output_column = output_column\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.test_size = test_size\n        self.df = None  # Will hold the loaded Polars DataFrame\n\n    def setup(self, stage: str = \"\") -> None:\n        \"\"\"Load and split the dataset into train and validation sets.\"\"\"\n        # Load dataset using Polars\n        self.df = pl.read_csv(self.data_path)\n\n        # Split the data into train and validation sets\n        train_df, val_df = train_test_split(self.df, test_size=self.test_size, random_state=42)\n\n        self.train_dataset = PolarsDataset(pl.DataFrame(train_df), output_column=self.output_column)\n        self.val_dataset = PolarsDataset(pl.DataFrame(val_df), output_column=self.output_column)\n\n    def train_dataloader(self) -> DataLoader:\n        \"\"\"Create and return the train dataloader.\"\"\"\n        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)\n\n    def val_dataloader(self) -> DataLoader:\n        \"\"\"Create and return the validation dataloader.\"\"\"\n        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.__init__","title":"__init__(data_path, output_column, batch_size=32, num_workers=0, test_size=0.2)","text":"

    Initialize the PolarsDataModule.

    Parameters:

    Name Type Description Default data_path str

    Path to the dataset.

    required output_column str

    Column name that contains the labels.

    required batch_size int

    Batch size for the dataloaders.

    32 num_workers int

    Number of workers for the dataloaders.

    0 test_size float

    Fraction of the dataset to be used for validation.

    0.2 Source code in src/data/polars_datamodule.py
    def __init__(\n    self, data_path: str, output_column: str, batch_size: int = 32, num_workers: int = 0, test_size: float = 0.2\n) -> None:\n    \"\"\"Initialize the PolarsDataModule.\n\n    Args:\n        data_path: Path to the dataset.\n        output_column: Column name that contains the labels.\n        batch_size: Batch size for the dataloaders.\n        num_workers: Number of workers for the dataloaders.\n        test_size: Fraction of the dataset to be used for validation.\n    \"\"\"\n    super().__init__()\n    self.data_path = data_path\n    self.output_column = output_column\n    self.batch_size = batch_size\n    self.num_workers = num_workers\n    self.test_size = test_size\n    self.df = None  # Will hold the loaded Polars DataFrame\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.setup","title":"setup(stage='')","text":"

    Load and split the dataset into train and validation sets.

    Source code in src/data/polars_datamodule.py
    def setup(self, stage: str = \"\") -> None:\n    \"\"\"Load and split the dataset into train and validation sets.\"\"\"\n    # Load dataset using Polars\n    self.df = pl.read_csv(self.data_path)\n\n    # Split the data into train and validation sets\n    train_df, val_df = train_test_split(self.df, test_size=self.test_size, random_state=42)\n\n    self.train_dataset = PolarsDataset(pl.DataFrame(train_df), output_column=self.output_column)\n    self.val_dataset = PolarsDataset(pl.DataFrame(val_df), output_column=self.output_column)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.train_dataloader","title":"train_dataloader()","text":"

    Create and return the train dataloader.

    Source code in src/data/polars_datamodule.py
    def train_dataloader(self) -> DataLoader:\n    \"\"\"Create and return the train dataloader.\"\"\"\n    return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.val_dataloader","title":"val_dataloader()","text":"

    Create and return the validation dataloader.

    Source code in src/data/polars_datamodule.py
    def val_dataloader(self) -> DataLoader:\n    \"\"\"Create and return the validation dataloader.\"\"\"\n    return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset","title":"PolarsDataset","text":"

    Bases: Dataset

    Custom PyTorch Dataset wrapping a Polars DataFrame.

    Source code in src/data/polars_datamodule.py
    class PolarsDataset(Dataset):\n    \"\"\"Custom PyTorch Dataset wrapping a Polars DataFrame.\"\"\"\n\n    def __init__(self, df: pl.DataFrame, output_column: str) -> None:\n        \"\"\"Initialize the PolarsDataset.\"\"\"\n        self.df = df\n        self.output_column = output_column\n\n    def __len__(self) -> int:\n        \"\"\"Return the number of rows in the dataset.\"\"\"\n        return self.df.shape[0]\n\n    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Return the features and label for the given index.\"\"\"\n        row = self.df[idx]\n        features = torch.tensor([val for col, val in row.items() if col != self.output_column], dtype=torch.float32)\n        label = torch.tensor(row[self.output_column], dtype=torch.long)\n        return features, label\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset.__getitem__","title":"__getitem__(idx)","text":"

    Return the features and label for the given index.

    Source code in src/data/polars_datamodule.py
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Return the features and label for the given index.\"\"\"\n    row = self.df[idx]\n    features = torch.tensor([val for col, val in row.items() if col != self.output_column], dtype=torch.float32)\n    label = torch.tensor(row[self.output_column], dtype=torch.long)\n    return features, label\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset.__init__","title":"__init__(df, output_column)","text":"

    Initialize the PolarsDataset.

    Source code in src/data/polars_datamodule.py
    def __init__(self, df: pl.DataFrame, output_column: str) -> None:\n    \"\"\"Initialize the PolarsDataset.\"\"\"\n    self.df = df\n    self.output_column = output_column\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset.__len__","title":"__len__()","text":"

    Return the number of rows in the dataset.

    Source code in src/data/polars_datamodule.py
    def __len__(self) -> int:\n    \"\"\"Return the number of rows in the dataset.\"\"\"\n    return self.df.shape[0]\n
    "},{"location":"api/models/mnist_module/","title":"Mnist module","text":"

    Mnist simple model.

    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule","title":"MNISTLitModule","text":"

    Bases: LightningModule

    Example of a LightningModule for MNIST classification.

    A LightningModule implements 8 key methods:

    def __init__(self):\n# Define initialization code here.\n\ndef setup(self, stage):\n# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.\n# This hook is called on every process when using DDP.\n\ndef training_step(self, batch, batch_idx):\n# The complete training step.\n\ndef validation_step(self, batch, batch_idx):\n# The complete validation step.\n\ndef test_step(self, batch, batch_idx):\n# The complete test step.\n\ndef predict_step(self, batch, batch_idx):\n# The complete predict step.\n\ndef configure_optimizers(self):\n# Define and configure optimizers and LR schedulers.\n
    Docs

    https://lightning.ai/docs/pytorch/latest/common/lightning_module.html

    Source code in src/models/mnist_module.py
    class MNISTLitModule(LightningModule):\n    \"\"\"Example of a `LightningModule` for MNIST classification.\n\n    A `LightningModule` implements 8 key methods:\n\n    ```python\n    def __init__(self):\n    # Define initialization code here.\n\n    def setup(self, stage):\n    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.\n    # This hook is called on every process when using DDP.\n\n    def training_step(self, batch, batch_idx):\n    # The complete training step.\n\n    def validation_step(self, batch, batch_idx):\n    # The complete validation step.\n\n    def test_step(self, batch, batch_idx):\n    # The complete test step.\n\n    def predict_step(self, batch, batch_idx):\n    # The complete predict step.\n\n    def configure_optimizers(self):\n    # Define and configure optimizers and LR schedulers.\n    ```\n\n    Docs:\n        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html\n    \"\"\"\n\n    def __init__(\n        self,\n        net: torch.nn.Module,\n        optimizer: torch.optim.Optimizer,\n        scheduler: torch.optim.lr_scheduler,\n        compile_model: bool,\n    ) -> None:\n        \"\"\"Initialize a `MNISTLitModule`.\n\n        Args:\n            net: The model to train.\n            optimizer: The optimizer to use for training.\n            scheduler: The learning rate scheduler to use for training.\n            compile_model: Whether or not compile the model.\n        \"\"\"\n        super().__init__()\n\n        # this line allows to access init params with 'self.hparams' attribute\n        # also ensures init params will be stored in ckpt\n        self.save_hyperparameters(logger=False)\n\n        self.net = net\n\n        # loss function\n        self.criterion = torch.nn.CrossEntropyLoss()\n\n        # metric objects for calculating and averaging accuracy across batches\n        self.train_acc = Accuracy(task=\"multiclass\", num_classes=10)\n        self.val_acc = Accuracy(task=\"multiclass\", num_classes=10)\n        self.test_acc = Accuracy(task=\"multiclass\", num_classes=10)\n\n        # for averaging loss across batches\n        self.train_loss = MeanMetric()\n        self.val_loss = MeanMetric()\n        self.test_loss = MeanMetric()\n\n        # for tracking best so far validation accuracy\n        self.val_acc_best = MaxMetric()\n\n    @typechecked\n    def forward(self, x: TensorType[\"batch\", 1, 28, 28]) -> TensorType[\"batch\", 10]:  # noqa\n        \"\"\"Perform a forward pass through the model.\n\n        Args:\n            x: A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.\n\n        Returns:\n            A tensor of shape (batch_size, 10) representing the logits for each class.\n        \"\"\"\n        return self.net(x)\n\n    def on_train_start(self) -> None:\n        \"\"\"Lightning hook that is called when training begins.\"\"\"\n        # by default lightning executes validation step sanity checks before training starts,\n        # so it's worth to make sure validation metrics don't store results from these checks\n        self.val_loss.reset()\n        self.val_acc.reset()\n        self.val_acc_best.reset()\n\n    @typechecked\n    def model_step(self, x: TensorType[\"batch\", 1, 28, 28], y: TensorType[\"batch\"]):  # noqa\n        \"\"\"Perform a single model step.\n\n        Args:\n            x: Tensor of shape [batch, 1, 28, 28] representing the images.\n            y: Tensor of shape [batch] representing the classes.\n\n        Returns:\n            A tuple containing:\n                - loss: A tensor of shape (batch_size,)\n                - preds: A tensor of predicted class indices (batch_size,)\n                - targets: A tensor of true class labels (batch_size,)\n        \"\"\"\n        logits = self.forward(x)\n        loss = self.criterion(logits, y)\n        preds = torch.argmax(logits, dim=1)\n        return loss, preds, y\n\n    @typechecked\n    def training_step(self, batch: Any) -> TensorType[()]:\n        \"\"\"Perform a single training step.\n\n        Args:\n            batch: A tuple containing input images and target labels.\n            batch_idx: The index of the current batch.\n\n        Returns:\n            A scalar loss tensor.\n        \"\"\"\n        x, y = batch\n        loss, preds, targets = self.model_step(x, y)\n        self.train_loss(loss)\n        self.train_acc(preds, targets)\n        self.log(\"train/loss\", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)\n        self.log(\"train/acc\", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)\n        return loss\n\n    def on_train_epoch_end(self) -> None:\n        \"\"\"Lightning hook that is called when a training epoch ends.\"\"\"\n        pass\n\n    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n        \"\"\"Perform a single validation step on a batch of data from the validation set.\n\n        Args:\n            batch: A batch of data (a tuple) containing the input tensor of images and target\n                labels.\n            batch_idx: The index of the current batch.\n        \"\"\"\n        x, y = batch\n        loss, preds, targets = self.model_step(x, y)\n\n        # update and log metrics\n        self.val_loss(loss)\n        self.val_acc(preds, targets)\n        self.log(\"val/loss\", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)\n        self.log(\"val/acc\", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)\n\n    def on_validation_epoch_end(self) -> None:\n        \"\"\"Lightning hook that is called when a validation epoch ends.\"\"\"\n        acc = self.val_acc.compute()  # get current val acc\n        self.val_acc_best(acc)  # update best so far val acc\n        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object\n        # otherwise metric would be reset by lightning after each epoch\n        self.log(\"val/acc_best\", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)\n\n    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n        \"\"\"Perform a single test step on a batch of data from the test set.\n\n        Args:\n            batch: A batch of data (a tuple) containing the input tensor of images and target\n                labels.\n            batch_idx: The index of the current batch.\n        \"\"\"\n        x, y = batch\n        loss, preds, targets = self.model_step(x, y)\n\n        # update and log metrics\n        self.test_loss(loss)\n        self.test_acc(preds, targets)\n        self.log(\"test/loss\", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)\n        self.log(\"test/acc\", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)\n\n    def on_test_epoch_end(self) -> None:\n        \"\"\"Lightning hook that is called when a test epoch ends.\"\"\"\n        pass\n\n    def setup(self, stage: str) -> None:\n        \"\"\"Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.\n\n        This is a good hook when you need to build models dynamically or adjust something about\n        them. This hook is called on every process when using DDP.\n\n        Args:\n            stage: Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n        \"\"\"\n        if self.hparams.compile_model and stage == \"fit\":\n            self.net = torch.compile(self.net)\n\n    def configure_optimizers(self) -> dict[str, Any]:\n        \"\"\"Choose what optimizers and learning-rate schedulers to use in your optimization.\n\n        Normally you'd need one. But in the case of GANs or similar you might have multiple.\n\n        Examples:\n            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers\n\n        Returns:\n            A dict containing the configured optimizers and learning-rate schedulers to be used for training.\n        \"\"\"\n        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())\n        if self.hparams.scheduler is not None:\n            scheduler = self.hparams.scheduler(optimizer=optimizer)\n            return {\n                \"optimizer\": optimizer,\n                \"lr_scheduler\": {\n                    \"scheduler\": scheduler,\n                    \"monitor\": \"val/loss\",\n                    \"interval\": \"epoch\",\n                    \"frequency\": 1,\n                },\n            }\n        return {\"optimizer\": optimizer}\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.__init__","title":"__init__(net, optimizer, scheduler, compile_model)","text":"

    Initialize a MNISTLitModule.

    Parameters:

    Name Type Description Default net Module

    The model to train.

    required optimizer Optimizer

    The optimizer to use for training.

    required scheduler lr_scheduler

    The learning rate scheduler to use for training.

    required compile_model bool

    Whether or not compile the model.

    required Source code in src/models/mnist_module.py
    def __init__(\n    self,\n    net: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    scheduler: torch.optim.lr_scheduler,\n    compile_model: bool,\n) -> None:\n    \"\"\"Initialize a `MNISTLitModule`.\n\n    Args:\n        net: The model to train.\n        optimizer: The optimizer to use for training.\n        scheduler: The learning rate scheduler to use for training.\n        compile_model: Whether or not compile the model.\n    \"\"\"\n    super().__init__()\n\n    # this line allows to access init params with 'self.hparams' attribute\n    # also ensures init params will be stored in ckpt\n    self.save_hyperparameters(logger=False)\n\n    self.net = net\n\n    # loss function\n    self.criterion = torch.nn.CrossEntropyLoss()\n\n    # metric objects for calculating and averaging accuracy across batches\n    self.train_acc = Accuracy(task=\"multiclass\", num_classes=10)\n    self.val_acc = Accuracy(task=\"multiclass\", num_classes=10)\n    self.test_acc = Accuracy(task=\"multiclass\", num_classes=10)\n\n    # for averaging loss across batches\n    self.train_loss = MeanMetric()\n    self.val_loss = MeanMetric()\n    self.test_loss = MeanMetric()\n\n    # for tracking best so far validation accuracy\n    self.val_acc_best = MaxMetric()\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.configure_optimizers","title":"configure_optimizers()","text":"

    Choose what optimizers and learning-rate schedulers to use in your optimization.

    Normally you'd need one. But in the case of GANs or similar you might have multiple.

    Examples:

    https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

    Returns:

    Type Description dict[str, Any]

    A dict containing the configured optimizers and learning-rate schedulers to be used for training.

    Source code in src/models/mnist_module.py
    def configure_optimizers(self) -> dict[str, Any]:\n    \"\"\"Choose what optimizers and learning-rate schedulers to use in your optimization.\n\n    Normally you'd need one. But in the case of GANs or similar you might have multiple.\n\n    Examples:\n        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers\n\n    Returns:\n        A dict containing the configured optimizers and learning-rate schedulers to be used for training.\n    \"\"\"\n    optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())\n    if self.hparams.scheduler is not None:\n        scheduler = self.hparams.scheduler(optimizer=optimizer)\n        return {\n            \"optimizer\": optimizer,\n            \"lr_scheduler\": {\n                \"scheduler\": scheduler,\n                \"monitor\": \"val/loss\",\n                \"interval\": \"epoch\",\n                \"frequency\": 1,\n            },\n        }\n    return {\"optimizer\": optimizer}\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.forward","title":"forward(x)","text":"

    Perform a forward pass through the model.

    Parameters:

    Name Type Description Default x TensorType[batch, 1, 28, 28]

    A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.

    required

    Returns:

    Type Description TensorType[batch, 10]

    A tensor of shape (batch_size, 10) representing the logits for each class.

    Source code in src/models/mnist_module.py
    @typechecked\ndef forward(self, x: TensorType[\"batch\", 1, 28, 28]) -> TensorType[\"batch\", 10]:  # noqa\n    \"\"\"Perform a forward pass through the model.\n\n    Args:\n        x: A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.\n\n    Returns:\n        A tensor of shape (batch_size, 10) representing the logits for each class.\n    \"\"\"\n    return self.net(x)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.model_step","title":"model_step(x, y)","text":"

    Perform a single model step.

    Parameters:

    Name Type Description Default x TensorType[batch, 1, 28, 28]

    Tensor of shape [batch, 1, 28, 28] representing the images.

    required y TensorType[batch]

    Tensor of shape [batch] representing the classes.

    required

    Returns:

    Type Description

    A tuple containing: - loss: A tensor of shape (batch_size,) - preds: A tensor of predicted class indices (batch_size,) - targets: A tensor of true class labels (batch_size,)

    Source code in src/models/mnist_module.py
    @typechecked\ndef model_step(self, x: TensorType[\"batch\", 1, 28, 28], y: TensorType[\"batch\"]):  # noqa\n    \"\"\"Perform a single model step.\n\n    Args:\n        x: Tensor of shape [batch, 1, 28, 28] representing the images.\n        y: Tensor of shape [batch] representing the classes.\n\n    Returns:\n        A tuple containing:\n            - loss: A tensor of shape (batch_size,)\n            - preds: A tensor of predicted class indices (batch_size,)\n            - targets: A tensor of true class labels (batch_size,)\n    \"\"\"\n    logits = self.forward(x)\n    loss = self.criterion(logits, y)\n    preds = torch.argmax(logits, dim=1)\n    return loss, preds, y\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_test_epoch_end","title":"on_test_epoch_end()","text":"

    Lightning hook that is called when a test epoch ends.

    Source code in src/models/mnist_module.py
    def on_test_epoch_end(self) -> None:\n    \"\"\"Lightning hook that is called when a test epoch ends.\"\"\"\n    pass\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_train_epoch_end","title":"on_train_epoch_end()","text":"

    Lightning hook that is called when a training epoch ends.

    Source code in src/models/mnist_module.py
    def on_train_epoch_end(self) -> None:\n    \"\"\"Lightning hook that is called when a training epoch ends.\"\"\"\n    pass\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_train_start","title":"on_train_start()","text":"

    Lightning hook that is called when training begins.

    Source code in src/models/mnist_module.py
    def on_train_start(self) -> None:\n    \"\"\"Lightning hook that is called when training begins.\"\"\"\n    # by default lightning executes validation step sanity checks before training starts,\n    # so it's worth to make sure validation metrics don't store results from these checks\n    self.val_loss.reset()\n    self.val_acc.reset()\n    self.val_acc_best.reset()\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_validation_epoch_end","title":"on_validation_epoch_end()","text":"

    Lightning hook that is called when a validation epoch ends.

    Source code in src/models/mnist_module.py
    def on_validation_epoch_end(self) -> None:\n    \"\"\"Lightning hook that is called when a validation epoch ends.\"\"\"\n    acc = self.val_acc.compute()  # get current val acc\n    self.val_acc_best(acc)  # update best so far val acc\n    # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object\n    # otherwise metric would be reset by lightning after each epoch\n    self.log(\"val/acc_best\", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.setup","title":"setup(stage)","text":"

    Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

    This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

    Parameters:

    Name Type Description Default stage str

    Either \"fit\", \"validate\", \"test\", or \"predict\".

    required Source code in src/models/mnist_module.py
    def setup(self, stage: str) -> None:\n    \"\"\"Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.\n\n    This is a good hook when you need to build models dynamically or adjust something about\n    them. This hook is called on every process when using DDP.\n\n    Args:\n        stage: Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n    \"\"\"\n    if self.hparams.compile_model and stage == \"fit\":\n        self.net = torch.compile(self.net)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.test_step","title":"test_step(batch, batch_idx)","text":"

    Perform a single test step on a batch of data from the test set.

    Parameters:

    Name Type Description Default batch tuple[Tensor, Tensor]

    A batch of data (a tuple) containing the input tensor of images and target labels.

    required batch_idx int

    The index of the current batch.

    required Source code in src/models/mnist_module.py
    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n    \"\"\"Perform a single test step on a batch of data from the test set.\n\n    Args:\n        batch: A batch of data (a tuple) containing the input tensor of images and target\n            labels.\n        batch_idx: The index of the current batch.\n    \"\"\"\n    x, y = batch\n    loss, preds, targets = self.model_step(x, y)\n\n    # update and log metrics\n    self.test_loss(loss)\n    self.test_acc(preds, targets)\n    self.log(\"test/loss\", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)\n    self.log(\"test/acc\", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.training_step","title":"training_step(batch)","text":"

    Perform a single training step.

    Parameters:

    Name Type Description Default batch Any

    A tuple containing input images and target labels.

    required batch_idx

    The index of the current batch.

    required

    Returns:

    Type Description TensorType[]

    A scalar loss tensor.

    Source code in src/models/mnist_module.py
    @typechecked\ndef training_step(self, batch: Any) -> TensorType[()]:\n    \"\"\"Perform a single training step.\n\n    Args:\n        batch: A tuple containing input images and target labels.\n        batch_idx: The index of the current batch.\n\n    Returns:\n        A scalar loss tensor.\n    \"\"\"\n    x, y = batch\n    loss, preds, targets = self.model_step(x, y)\n    self.train_loss(loss)\n    self.train_acc(preds, targets)\n    self.log(\"train/loss\", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)\n    self.log(\"train/acc\", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)\n    return loss\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.validation_step","title":"validation_step(batch, batch_idx)","text":"

    Perform a single validation step on a batch of data from the validation set.

    Parameters:

    Name Type Description Default batch tuple[Tensor, Tensor]

    A batch of data (a tuple) containing the input tensor of images and target labels.

    required batch_idx int

    The index of the current batch.

    required Source code in src/models/mnist_module.py
    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n    \"\"\"Perform a single validation step on a batch of data from the validation set.\n\n    Args:\n        batch: A batch of data (a tuple) containing the input tensor of images and target\n            labels.\n        batch_idx: The index of the current batch.\n    \"\"\"\n    x, y = batch\n    loss, preds, targets = self.model_step(x, y)\n\n    # update and log metrics\n    self.val_loss(loss)\n    self.val_acc(preds, targets)\n    self.log(\"val/loss\", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)\n    self.log(\"val/acc\", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)\n
    "},{"location":"api/models/components/simple_dense_net/","title":"Simple dense net","text":"

    Simple dense neural network.

    "},{"location":"api/models/components/simple_dense_net/#src.models.components.simple_dense_net.SimpleDenseNet","title":"SimpleDenseNet","text":"

    Bases: Module

    A simple fully-connected neural net for computing predictions.

    Source code in src/models/components/simple_dense_net.py
    class SimpleDenseNet(nn.Module):\n    \"\"\"A simple fully-connected neural net for computing predictions.\"\"\"\n\n    def __init__(\n        self,\n        input_size: int = 784,\n        lin1_size: int = 256,\n        lin2_size: int = 256,\n        lin3_size: int = 256,\n        output_size: int = 10,\n    ) -> None:\n        \"\"\"Initialize a `SimpleDenseNet` module.\n\n        Args:\n            input_size: The number of input features.\n            lin1_size: The number of output features of the first linear layer.\n            lin2_size: The number of output features of the second linear layer.\n            lin3_size: The number of output features of the third linear layer.\n            output_size: The number of output features of the final linear layer.\n        \"\"\"\n        super().__init__()\n\n        self.model = nn.Sequential(\n            nn.Linear(input_size, lin1_size),\n            nn.BatchNorm1d(lin1_size),\n            nn.ReLU(),\n            nn.Linear(lin1_size, lin2_size),\n            nn.BatchNorm1d(lin2_size),\n            nn.ReLU(),\n            nn.Linear(lin2_size, lin3_size),\n            nn.BatchNorm1d(lin3_size),\n            nn.ReLU(),\n            nn.Linear(lin3_size, output_size),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Perform a single forward pass through the network.\n\n        Args:\n            x: The input tensor.\n\n        Returns:\n            A tensor of predictions.\n        \"\"\"\n        batch_size, channels, width, height = x.size()\n\n        # (batch, 1, width, height) -> (batch, 1*width*height)\n        x = x.view(batch_size, -1)\n\n        return self.model(x)\n
    "},{"location":"api/models/components/simple_dense_net/#src.models.components.simple_dense_net.SimpleDenseNet.__init__","title":"__init__(input_size=784, lin1_size=256, lin2_size=256, lin3_size=256, output_size=10)","text":"

    Initialize a SimpleDenseNet module.

    Parameters:

    Name Type Description Default input_size int

    The number of input features.

    784 lin1_size int

    The number of output features of the first linear layer.

    256 lin2_size int

    The number of output features of the second linear layer.

    256 lin3_size int

    The number of output features of the third linear layer.

    256 output_size int

    The number of output features of the final linear layer.

    10 Source code in src/models/components/simple_dense_net.py
    def __init__(\n    self,\n    input_size: int = 784,\n    lin1_size: int = 256,\n    lin2_size: int = 256,\n    lin3_size: int = 256,\n    output_size: int = 10,\n) -> None:\n    \"\"\"Initialize a `SimpleDenseNet` module.\n\n    Args:\n        input_size: The number of input features.\n        lin1_size: The number of output features of the first linear layer.\n        lin2_size: The number of output features of the second linear layer.\n        lin3_size: The number of output features of the third linear layer.\n        output_size: The number of output features of the final linear layer.\n    \"\"\"\n    super().__init__()\n\n    self.model = nn.Sequential(\n        nn.Linear(input_size, lin1_size),\n        nn.BatchNorm1d(lin1_size),\n        nn.ReLU(),\n        nn.Linear(lin1_size, lin2_size),\n        nn.BatchNorm1d(lin2_size),\n        nn.ReLU(),\n        nn.Linear(lin2_size, lin3_size),\n        nn.BatchNorm1d(lin3_size),\n        nn.ReLU(),\n        nn.Linear(lin3_size, output_size),\n    )\n
    "},{"location":"api/models/components/simple_dense_net/#src.models.components.simple_dense_net.SimpleDenseNet.forward","title":"forward(x)","text":"

    Perform a single forward pass through the network.

    Parameters:

    Name Type Description Default x Tensor

    The input tensor.

    required

    Returns:

    Type Description Tensor

    A tensor of predictions.

    Source code in src/models/components/simple_dense_net.py
    def forward(self, x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Perform a single forward pass through the network.\n\n    Args:\n        x: The input tensor.\n\n    Returns:\n        A tensor of predictions.\n    \"\"\"\n    batch_size, channels, width, height = x.size()\n\n    # (batch, 1, width, height) -> (batch, 1*width*height)\n    x = x.view(batch_size, -1)\n\n    return self.model(x)\n
    "},{"location":"api/serve_apis/mnist_serve/","title":"Mnist serve","text":"

    This is an example of a LitServe api for the Mnist LightningModule.

    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI","title":"MNISTServeAPI","text":"

    Bases: LitAPI

    LitServe API for serving the MNIST model.

    Source code in src/serve_apis/mnist_serve.py
    class MNISTServeAPI(ls.LitAPI):\n    \"\"\"LitServe API for serving the MNIST model.\"\"\"\n\n    def __init__(self, model_class: lightning.pytorch.LightningModule, checkpoint_path: str):\n        \"\"\"Initialize the MNISTServeAPI.\n\n        Args:\n            model_class: The LightningModule class to serve.\n            checkpoint_path: The path to the model checkpoint.\n        \"\"\"\n        self.checkpoint_path = checkpoint_path\n        self.model_class = model_class\n\n    def setup(self, device: str):\n        \"\"\"Setup is called once at startup.\n\n        Load the model, set the device, and prepare any other necessary components.\n        \"\"\"\n        # Load the trained MNIST model (ensure model weights are loaded properly here)\n        self.model = self.model_class.load_from_checkpoint(self.checkpoint_path)\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        self.model.to(device)  # Move the model to the appropriate device (CPU or GPU)\n        self.model.eval()  # Set the model to evaluation mode\n\n        # Define transforms that match the training data processing pipeline\n        self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n\n    def decode_request(self, request: dict):\n        \"\"\"Decode the incoming request and prepare the input for the model.\"\"\"\n        # Convert the request payload into a tensor for model input\n        image_data = request[\"image\"]\n        # Ensure that the image is a tensor of shape [1, 28, 28] (MNIST image dimensions)\n        image_tensor = torch.tensor(image_data).unsqueeze(0)  # Add a batch dimension\n        return self.transforms(image_tensor)  # Apply the necessary transformations\n\n    def predict(self, x: torch.Tensor):\n        \"\"\"Run inference using the MNIST model and return the prediction.\"\"\"\n        # Forward pass through the model\n        with torch.no_grad():\n            logits = self.model(x.unsqueeze(0))  # Add batch dimension for inference\n            preds = torch.argmax(logits, dim=1)  # Get the predicted class\n        return {\"prediction\": preds.item()}  # Return the prediction as a dictionary\n\n    def encode_response(self, output: dict):\n        \"\"\"Encode the model's output into a response payload.\"\"\"\n        # Simply pass the output as the response\n        return {\"predicted_class\": output[\"prediction\"]}\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.__init__","title":"__init__(model_class, checkpoint_path)","text":"

    Initialize the MNISTServeAPI.

    Parameters:

    Name Type Description Default model_class LightningModule

    The LightningModule class to serve.

    required checkpoint_path str

    The path to the model checkpoint.

    required Source code in src/serve_apis/mnist_serve.py
    def __init__(self, model_class: lightning.pytorch.LightningModule, checkpoint_path: str):\n    \"\"\"Initialize the MNISTServeAPI.\n\n    Args:\n        model_class: The LightningModule class to serve.\n        checkpoint_path: The path to the model checkpoint.\n    \"\"\"\n    self.checkpoint_path = checkpoint_path\n    self.model_class = model_class\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.decode_request","title":"decode_request(request)","text":"

    Decode the incoming request and prepare the input for the model.

    Source code in src/serve_apis/mnist_serve.py
    def decode_request(self, request: dict):\n    \"\"\"Decode the incoming request and prepare the input for the model.\"\"\"\n    # Convert the request payload into a tensor for model input\n    image_data = request[\"image\"]\n    # Ensure that the image is a tensor of shape [1, 28, 28] (MNIST image dimensions)\n    image_tensor = torch.tensor(image_data).unsqueeze(0)  # Add a batch dimension\n    return self.transforms(image_tensor)  # Apply the necessary transformations\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.encode_response","title":"encode_response(output)","text":"

    Encode the model's output into a response payload.

    Source code in src/serve_apis/mnist_serve.py
    def encode_response(self, output: dict):\n    \"\"\"Encode the model's output into a response payload.\"\"\"\n    # Simply pass the output as the response\n    return {\"predicted_class\": output[\"prediction\"]}\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.predict","title":"predict(x)","text":"

    Run inference using the MNIST model and return the prediction.

    Source code in src/serve_apis/mnist_serve.py
    def predict(self, x: torch.Tensor):\n    \"\"\"Run inference using the MNIST model and return the prediction.\"\"\"\n    # Forward pass through the model\n    with torch.no_grad():\n        logits = self.model(x.unsqueeze(0))  # Add batch dimension for inference\n        preds = torch.argmax(logits, dim=1)  # Get the predicted class\n    return {\"prediction\": preds.item()}  # Return the prediction as a dictionary\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.setup","title":"setup(device)","text":"

    Setup is called once at startup.

    Load the model, set the device, and prepare any other necessary components.

    Source code in src/serve_apis/mnist_serve.py
    def setup(self, device: str):\n    \"\"\"Setup is called once at startup.\n\n    Load the model, set the device, and prepare any other necessary components.\n    \"\"\"\n    # Load the trained MNIST model (ensure model weights are loaded properly here)\n    self.model = self.model_class.load_from_checkpoint(self.checkpoint_path)\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    self.model.to(device)  # Move the model to the appropriate device (CPU or GPU)\n    self.model.eval()  # Set the model to evaluation mode\n\n    # Define transforms that match the training data processing pipeline\n    self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n
    "},{"location":"api/utils/download_utils/","title":"Download utils","text":"

    Utility functions aimed at downloading any data from external sources.

    "},{"location":"api/utils/download_utils/#src.utils.download_utils.download_cloud_directory","title":"download_cloud_directory(cloud_directory, output_folder, cloud='gs')","text":"

    Download a given cloud directory.

    Parameters:

    Name Type Description Default cloud_directory str

    for example gs://bucket-name/path/to/directory

    required output_folder str

    where the data downloaded will be stored (ideally data/ folder)

    required cloud str

    the cloud provider, currently only \"gs\" is supported

    'gs' Source code in src/utils/download_utils.py
    def download_cloud_directory(cloud_directory: str, output_folder: str, cloud: str = \"gs\") -> None:\n    \"\"\"Download a given cloud directory.\n\n    Args:\n        cloud_directory: for example gs://bucket-name/path/to/directory\n        output_folder: where the data downloaded will be stored (ideally data/ folder)\n        cloud: the cloud provider, currently only \"gs\" is supported\n    \"\"\"\n    cloudpathlib.Path(cloud_directory).download_to(output_folder)\n
    "},{"location":"api/utils/download_utils/#src.utils.download_utils.download_kaggle_dataset","title":"download_kaggle_dataset(dataset_name, output_folder)","text":"

    Download a given Kaggle dataset.

    Parameters:

    Name Type Description Default dataset_name str

    for example googleai/pfam-seed-random-split

    required output_folder str

    where the data downloaded will be stored (ideally data/ folder)

    required Source code in src/utils/download_utils.py
    def download_kaggle_dataset(dataset_name: str, output_folder: str) -> None:\n    \"\"\"Download a given Kaggle dataset.\n\n    Args:\n        dataset_name: for example googleai/pfam-seed-random-split\n        output_folder: where the data downloaded will be stored (ideally data/ folder)\n    \"\"\"\n    from kaggle.api.kaggle_api_extended import KaggleApi\n\n    api = KaggleApi()\n    log.info(\"Authenticating to Kaggle API\")\n    api.authenticate()\n    log.info(\"Downloading dataset\")\n    api.dataset_download_files(dataset_name, path=output_folder, unzip=True, quiet=False)\n    log.info(\"Download successful\")\n
    "},{"location":"api/utils/instantiators/","title":"Instantiators","text":"

    Module to instantiate different objects types.

    "},{"location":"api/utils/instantiators/#src.utils.instantiators.instantiate_callbacks","title":"instantiate_callbacks(callbacks_cfg)","text":"

    Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations. :return: A list of instantiated callbacks.

    Source code in src/utils/instantiators.py
    def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:\n    \"\"\"Instantiates callbacks from config.\n\n    :param callbacks_cfg: A DictConfig object containing callback configurations.\n    :return: A list of instantiated callbacks.\n    \"\"\"\n    callbacks: list[Callback] = []\n\n    if not callbacks_cfg:\n        log.warning(\"No callback configs found! Skipping..\")\n        return callbacks\n\n    if not isinstance(callbacks_cfg, DictConfig):\n        raise TypeError(\"Callbacks config must be a DictConfig!\")  # noqa: TRY003\n\n    for _, cb_conf in callbacks_cfg.items():\n        if isinstance(cb_conf, DictConfig) and \"_target_\" in cb_conf:\n            log.info(f\"Instantiating callback <{cb_conf._target_}>\")\n            callbacks.append(hydra.utils.instantiate(cb_conf))\n\n    return callbacks\n
    "},{"location":"api/utils/instantiators/#src.utils.instantiators.instantiate_loggers","title":"instantiate_loggers(logger_cfg)","text":"

    Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations. :return: A list of instantiated loggers.

    Source code in src/utils/instantiators.py
    def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]:\n    \"\"\"Instantiates loggers from config.\n\n    :param logger_cfg: A DictConfig object containing logger configurations.\n    :return: A list of instantiated loggers.\n    \"\"\"\n    logger: list[Logger] = []\n\n    if not logger_cfg:\n        log.warning(\"No logger configs found! Skipping...\")\n        return logger\n\n    if not isinstance(logger_cfg, DictConfig):\n        raise TypeError(\"Logger config must be a DictConfig!\")  # noqa: TRY003\n\n    for _, lg_conf in logger_cfg.items():\n        if isinstance(lg_conf, DictConfig) and \"_target_\" in lg_conf:\n            log.info(f\"Instantiating logger <{lg_conf._target_}>\")\n            logger.append(hydra.utils.instantiate(lg_conf))\n\n    return logger\n
    "},{"location":"api/utils/logging_utils/","title":"Logging utils","text":"

    Logging utility instantiator.

    "},{"location":"api/utils/logging_utils/#src.utils.logging_utils.log_hyperparameters","title":"log_hyperparameters(object_dict)","text":"

    Controls which config parts are saved by Lightning loggers.

    Additionally saves number of model parameters

    Parameters:

    Name Type Description Default object_dict dict[str, Any]

    A dictionary containing the following objects: cfg, model, trainer.

    required Source code in src/utils/logging_utils.py
    @rank_zero_only\ndef log_hyperparameters(object_dict: dict[str, Any]) -> None:\n    \"\"\"Controls which config parts are saved by Lightning loggers.\n\n    Additionally saves number of model parameters\n\n    Args:\n        object_dict: A dictionary containing the following objects: cfg, model, trainer.\n    \"\"\"\n    hparams = {}\n\n    cfg = OmegaConf.to_container(object_dict[\"cfg\"])\n    model = object_dict[\"model\"]\n    trainer = object_dict[\"trainer\"]\n\n    if not trainer.logger:\n        log.warning(\"Logger not found! Skipping hyperparameter logging...\")\n        return\n\n    hparams[\"model\"] = cfg[\"model\"]\n\n    # save number of model parameters\n    hparams[\"model/params/total\"] = sum(p.numel() for p in model.parameters())\n    hparams[\"model/params/trainable\"] = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    hparams[\"model/params/non_trainable\"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)\n\n    hparams[\"data\"] = cfg[\"data\"]\n    hparams[\"trainer\"] = cfg[\"trainer\"]\n\n    hparams[\"callbacks\"] = cfg.get(\"callbacks\")\n    hparams[\"extras\"] = cfg.get(\"extras\")\n\n    hparams[\"task_name\"] = cfg.get(\"task_name\")\n    hparams[\"tags\"] = cfg.get(\"tags\")\n    hparams[\"ckpt_path\"] = cfg.get(\"ckpt_path\")\n    hparams[\"seed\"] = cfg.get(\"seed\")\n    hparams[\"execution_command\"] = f\"python {' '.join(sys.argv)}\"\n\n    # send hparams to all loggers\n    for logger in trainer.loggers:\n        logger.log_hyperparams(hparams)\n
    "},{"location":"api/utils/pylogger/","title":"Pylogger","text":"

    Code for logging on multi-GPU-friendly.

    "},{"location":"api/utils/pylogger/#src.utils.pylogger.RankedLogger","title":"RankedLogger","text":"

    Bases: LoggerAdapter

    A multi-GPU-friendly python command line logger.

    Source code in src/utils/pylogger.py
    class RankedLogger(logging.LoggerAdapter):\n    \"\"\"A multi-GPU-friendly python command line logger.\"\"\"\n\n    def __init__(\n        self,\n        name: str = __name__,\n        rank_zero_only: bool = False,\n        extra: Mapping[str, object] | None = None,\n    ) -> None:\n        \"\"\"Initializes a multi-GPU-friendly python command line logger that logs.\n\n        On all processes with their rank prefixed in the log message.\n\n        Args:\n            name: The name of the logger. Default is ``__name__``.\n            rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.\n            extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.\n        \"\"\"\n        logger = logging.getLogger(name)\n        super().__init__(logger=logger, extra=extra)\n        self.rank_zero_only = rank_zero_only\n\n    def log(self, level: int, msg: str, rank: int | None = None, *args, **kwargs) -> None:  # type: ignore\n        \"\"\"Delegate a log call to the underlying logger.\n\n        After prefixing its message with the rank\n        of the process it's being logged from. If `'rank'` is provided, then the log will only\n        occur on that rank/process.\n\n        Args:\n            level: The level to log at. Look at `logging.__init__.py` for more information.\n            msg: The message to log.\n            rank: The rank to log at.\n            args: Additional args to pass to the underlying logging function.\n            kwargs: Any additional keyword args to pass to the underlying logging function.\n        \"\"\"\n        if self.isEnabledFor(level):\n            msg, kwargs = self.process(msg, kwargs)  # type: ignore\n            current_rank = getattr(rank_zero_only, \"rank\", None)\n            if current_rank is None:\n                raise RuntimeError(\"The `rank_zero_only.rank` needs to be set before use\")  # noqa\n            msg = rank_prefixed_message(msg, current_rank)\n            if self.rank_zero_only:\n                if current_rank == 0:\n                    self.logger.log(level, msg, *args, **kwargs)\n            else:\n                if rank is None or current_rank == rank:\n                    self.logger.log(level, msg, *args, **kwargs)\n
    "},{"location":"api/utils/pylogger/#src.utils.pylogger.RankedLogger.__init__","title":"__init__(name=__name__, rank_zero_only=False, extra=None)","text":"

    Initializes a multi-GPU-friendly python command line logger that logs.

    On all processes with their rank prefixed in the log message.

    Parameters:

    Name Type Description Default name str

    The name of the logger. Default is __name__.

    __name__ rank_zero_only bool

    Whether to force all logs to only occur on the rank zero process. Default is False.

    False extra Mapping[str, object] | None

    (Optional) A dict-like object which provides contextual information. See logging.LoggerAdapter.

    None Source code in src/utils/pylogger.py
    def __init__(\n    self,\n    name: str = __name__,\n    rank_zero_only: bool = False,\n    extra: Mapping[str, object] | None = None,\n) -> None:\n    \"\"\"Initializes a multi-GPU-friendly python command line logger that logs.\n\n    On all processes with their rank prefixed in the log message.\n\n    Args:\n        name: The name of the logger. Default is ``__name__``.\n        rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.\n        extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.\n    \"\"\"\n    logger = logging.getLogger(name)\n    super().__init__(logger=logger, extra=extra)\n    self.rank_zero_only = rank_zero_only\n
    "},{"location":"api/utils/pylogger/#src.utils.pylogger.RankedLogger.log","title":"log(level, msg, rank=None, *args, **kwargs)","text":"

    Delegate a log call to the underlying logger.

    After prefixing its message with the rank of the process it's being logged from. If 'rank' is provided, then the log will only occur on that rank/process.

    Parameters:

    Name Type Description Default level int

    The level to log at. Look at logging.__init__.py for more information.

    required msg str

    The message to log.

    required rank int | None

    The rank to log at.

    None args

    Additional args to pass to the underlying logging function.

    () kwargs

    Any additional keyword args to pass to the underlying logging function.

    {} Source code in src/utils/pylogger.py
    def log(self, level: int, msg: str, rank: int | None = None, *args, **kwargs) -> None:  # type: ignore\n    \"\"\"Delegate a log call to the underlying logger.\n\n    After prefixing its message with the rank\n    of the process it's being logged from. If `'rank'` is provided, then the log will only\n    occur on that rank/process.\n\n    Args:\n        level: The level to log at. Look at `logging.__init__.py` for more information.\n        msg: The message to log.\n        rank: The rank to log at.\n        args: Additional args to pass to the underlying logging function.\n        kwargs: Any additional keyword args to pass to the underlying logging function.\n    \"\"\"\n    if self.isEnabledFor(level):\n        msg, kwargs = self.process(msg, kwargs)  # type: ignore\n        current_rank = getattr(rank_zero_only, \"rank\", None)\n        if current_rank is None:\n            raise RuntimeError(\"The `rank_zero_only.rank` needs to be set before use\")  # noqa\n        msg = rank_prefixed_message(msg, current_rank)\n        if self.rank_zero_only:\n            if current_rank == 0:\n                self.logger.log(level, msg, *args, **kwargs)\n        else:\n            if rank is None or current_rank == rank:\n                self.logger.log(level, msg, *args, **kwargs)\n
    "},{"location":"api/utils/rich_utils/","title":"Rich utils","text":"

    Rich utils to print config tree.

    "},{"location":"api/utils/rich_utils/#src.utils.rich_utils.enforce_tags","title":"enforce_tags(cfg, save_to_file=False)","text":"

    Prompts user to input tags from command line if no tags are provided in config.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig composed by Hydra.

    required save_to_file bool

    Whether to export tags to the hydra output folder. Default is False.

    False Source code in src/utils/rich_utils.py
    @rank_zero_only\ndef enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:\n    \"\"\"Prompts user to input tags from command line if no tags are provided in config.\n\n    Args:\n        cfg: A DictConfig composed by Hydra.\n        save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.\n    \"\"\"\n    if not cfg.get(\"tags\"):\n        if \"id\" in HydraConfig().cfg.hydra.job:\n            raise ValueError(\"Specify tags before launching a multirun!\")  # noqa\n\n        log.warning(\"No tags provided in config. Prompting user to input tags...\")\n        tags = Prompt.ask(\"Enter a list of comma separated tags\", default=\"dev\")\n        tags = [t.strip() for t in tags.split(\",\") if t != \"\"]\n\n        with open_dict(cfg):\n            cfg.tags = tags\n\n        log.info(f\"Tags: {cfg.tags}\")\n\n    if save_to_file:\n        with open(Path(cfg.paths.output_dir, \"tags.log\"), \"w\") as file:\n            rich.print(cfg.tags, file=file)\n
    "},{"location":"api/utils/rich_utils/#src.utils.rich_utils.print_config_tree","title":"print_config_tree(cfg, print_order=('data', 'model', 'callbacks', 'logger', 'trainer', 'paths', 'extras'), resolve=False, save_to_file=False)","text":"

    Prints the contents of a DictConfig as a tree structure using the Rich library.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig composed by Hydra.

    required print_order Sequence[str]

    Determines in what order config components are printed. Default is ``(\"data\", \"model\",

    ('data', 'model', 'callbacks', 'logger', 'trainer', 'paths', 'extras') resolve bool

    Whether to resolve reference fields of DictConfig. Default is False.

    False save_to_file bool

    Whether to export config to the hydra output folder. Default is False.

    False Source code in src/utils/rich_utils.py
    @rank_zero_only\ndef print_config_tree(\n    cfg: DictConfig,\n    print_order: Sequence[str] = (\n        \"data\",\n        \"model\",\n        \"callbacks\",\n        \"logger\",\n        \"trainer\",\n        \"paths\",\n        \"extras\",\n    ),\n    resolve: bool = False,\n    save_to_file: bool = False,\n) -> None:\n    \"\"\"Prints the contents of a DictConfig as a tree structure using the Rich library.\n\n    Args:\n        cfg: A DictConfig composed by Hydra.\n        print_order: Determines in what order config components are printed. Default is ``(\"data\", \"model\",\n        \"callbacks\", \"logger\", \"trainer\", \"paths\", \"extras\")``.\n        resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.\n        save_to_file: Whether to export config to the hydra output folder. Default is ``False``.\n    \"\"\"\n    style = \"dim\"\n    tree = rich.tree.Tree(\"CONFIG\", style=style, guide_style=style)\n\n    queue = []\n\n    # add fields from `print_order` to queue\n    for field in print_order:\n        queue.append(field) if field in cfg else log.warning(\n            f\"Field '{field}' not found in config. Skipping '{field}' config printing...\"\n        )\n\n    # add all the other fields to queue (not specified in `print_order`)\n    for field in cfg:\n        if field not in queue:\n            queue.append(field)\n\n    # generate config tree from queue\n    for field in queue:\n        branch = tree.add(field, style=style, guide_style=style)\n\n        config_group = cfg[field]\n        if isinstance(config_group, DictConfig):\n            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)\n        else:\n            branch_content = str(config_group)\n\n        branch.add(rich.syntax.Syntax(branch_content, \"yaml\"))\n\n    # print config tree\n    rich.print(tree)\n\n    # save config tree to file\n    if save_to_file:\n        with open(Path(cfg.paths.output_dir, \"config_tree.log\"), \"w\") as file:\n            rich.print(tree, file=file)\n
    "},{"location":"api/utils/utils/","title":"Utils","text":"

    Utility functions for various tasks.

    "},{"location":"api/utils/utils/#src.utils.utils.extras","title":"extras(cfg)","text":"

    Applies optional utilities before the task is started.

    Utilities

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig object containing the config tree.

    required Source code in src/utils/utils.py
    def extras(cfg: DictConfig) -> None:\n    \"\"\"Applies optional utilities before the task is started.\n\n    Utilities:\n        - Ignoring python warnings\n        - Setting tags from command line\n        - Rich config printing\n\n    Args:\n        cfg: A DictConfig object containing the config tree.\n    \"\"\"\n    # return if no `extras` config\n    if not cfg.get(\"extras\"):\n        log.warning(\"Extras config not found! <cfg.extras=null>\")\n        return\n\n    # disable python warnings\n    if cfg.extras.get(\"ignore_warnings\"):\n        log.info(\"Disabling python warnings! <cfg.extras.ignore_warnings=True>\")\n        warnings.filterwarnings(\"ignore\")\n\n    # prompt user to input tags from command line if none are provided in the config\n    if cfg.extras.get(\"enforce_tags\"):\n        log.info(\"Enforcing tags! <cfg.extras.enforce_tags=True>\")\n        rich_utils.enforce_tags(cfg, save_to_file=True)\n\n    # pretty print config tree using Rich library\n    if cfg.extras.get(\"print_config\"):\n        log.info(\"Printing config tree with Rich! <cfg.extras.print_config=True>\")\n        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)\n
    "},{"location":"api/utils/utils/#src.utils.utils.fetch_data","title":"fetch_data(url)","text":"

    Fetches data from a URL.

    Source code in src/utils/utils.py
    def fetch_data(url):\n    \"\"\"Fetches data from a URL.\"\"\"\n    response = requests.get(url)\n    if response.status_code == 200:\n        return response.json()\n    return None\n
    "},{"location":"api/utils/utils/#src.utils.utils.file_lock","title":"file_lock(filename, mode='r')","text":"

    This context manager is used to acquire a file lock on a file.

    particularly useful for shared resources in multi-process environments (multi GPU/TPU training).

    Parameters:

    Name Type Description Default filename Path

    Path to the file to lock

    required mode str

    The mode to open the file with, either \"r\" or \"w\"

    'r'

    Raises:

    Type Description ValueError

    If the mode is invalid (neither \"r\" nor \"w\")

    Source code in src/utils/utils.py
    @contextlib.contextmanager\ndef file_lock(filename: Path, mode: str = \"r\") -> Any:\n    \"\"\"This context manager is used to acquire a file lock on a file.\n\n    particularly useful for shared resources in multi-process environments (multi GPU/TPU training).\n\n    Args:\n        filename: Path to the file to lock\n        mode: The mode to open the file with, either \"r\" or \"w\"\n\n    Raises:\n        ValueError: If the mode is invalid (neither \"r\" nor \"w\")\n    \"\"\"\n    with open(filename, mode) as f:\n        try:\n            match mode:\n                case \"r\":\n                    fcntl.flock(f.fileno(), fcntl.LOCK_SH)\n                case \"w\":\n                    fcntl.flock(f.fileno(), fcntl.LOCK_EX)\n                case _:\n                    raise ValueError(\"Expected mode 'r' or 'w'.\")  # noqa\n            yield f\n        finally:\n            fcntl.flock(f.fileno(), fcntl.LOCK_UN)\n
    "},{"location":"api/utils/utils/#src.utils.utils.file_lock_operation","title":"file_lock_operation(file_name, operation)","text":"

    This function is used to perform an operation on a file while acquiring a lock on it.

    The lock is acquired using the file_lock context manager, and based on a file stored in a temporary folder

    Parameters:

    Name Type Description Default file_name str

    Path to the file to lock

    required operation Callable

    The operation to perform on the file

    required

    Returns:

    Type Description Any

    The result of the operation

    Source code in src/utils/utils.py
    @contextlib.contextmanager\ndef file_lock_operation(file_name: str, operation: Callable) -> Any:\n    \"\"\"This function is used to perform an operation on a file while acquiring a lock on it.\n\n    The lock is acquired using the `file_lock` context manager, and based on a file stored in a temporary folder\n\n    Args:\n        file_name: Path to the file to lock\n        operation: The operation to perform on the file\n\n    Returns:\n        The result of the operation\n    \"\"\"\n    with tempfile.TemporaryDirectory() as temp_dir:\n        file_path = Path(temp_dir) / file_name\n        with file_lock(file_path, mode=\"w\"):\n            result = operation(file_path)\n        return result\n
    "},{"location":"api/utils/utils/#src.utils.utils.get_metric_value","title":"get_metric_value(metric_dict, metric_name)","text":"

    Safely retrieves value of the metric logged in LightningModule.

    Parameters:

    Name Type Description Default metric_dict dict[str, Any]

    A dict containing metric values.

    required metric_name str | None

    If provided, the name of the metric to retrieve.

    required

    Returns:

    Type Description None | float

    If a metric name was provided, the value of the metric.

    Source code in src/utils/utils.py
    def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> None | float:\n    \"\"\"Safely retrieves value of the metric logged in LightningModule.\n\n    Args:\n        metric_dict: A dict containing metric values.\n        metric_name: If provided, the name of the metric to retrieve.\n\n    Returns:\n        If a metric name was provided, the value of the metric.\n    \"\"\"\n    if not metric_name:\n        log.info(\"Metric name is None! Skipping metric value retrieval...\")\n        return None\n\n    if metric_name not in metric_dict:\n        raise ValueError(f\"Metric value not found! <metric_name={metric_name}>\\n\")  # noqa: TRY003\n\n    metric_value = metric_dict[metric_name].item()\n    log.info(f\"Retrieved metric value! <{metric_name}={metric_value}>\")\n\n    return metric_value\n
    "},{"location":"api/utils/utils/#src.utils.utils.process_data","title":"process_data(url)","text":"

    Fetches data from a URL and processes it.

    Source code in src/utils/utils.py
    def process_data(url):\n    \"\"\"Fetches data from a URL and processes it.\"\"\"\n    data = fetch_data(url)\n    if data:\n        return len(data)  # Just an example of processing, counting data length\n    return 0\n
    "},{"location":"api/utils/utils/#src.utils.utils.task_wrapper","title":"task_wrapper(task_func)","text":"

    Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to

    Example:

    @utils.task_wrapper\ndef train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n    ...\n    return metric_dict, object_dict\n

    Parameters:

    Name Type Description Default task_func Callable

    The task function to be wrapped.

    required

    Returns:

    Type Description Callable

    The wrapped task function.

    Source code in src/utils/utils.py
    def task_wrapper(task_func: Callable) -> Callable:\n    \"\"\"Optional decorator that controls the failure behavior when executing the task function.\n\n    This wrapper can be used to:\n        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)\n        - save the exception to a `.log` file\n        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)\n        - etc. (adjust depending on your needs)\n\n    Example:\n    ```\n    @utils.task_wrapper\n    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n        ...\n        return metric_dict, object_dict\n    ```\n\n    Args:\n        task_func: The task function to be wrapped.\n\n    Returns:\n        The wrapped task function.\n    \"\"\"\n\n    def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:\n        # execute the task\n        try:\n            metric_dict, object_dict = task_func(cfg=cfg)\n\n        # things to do if exception occurs\n        except Exception as e:\n            # save exception to `.log` file\n            log.exception(\"\")\n\n            # some hyperparameter combinations might be invalid or cause out-of-memory errors\n            # so when using hparam search plugins like Optuna, you might want to disable\n            # raising the below exception to avoid multirun failure\n            raise e  # noqa: TRY201\n\n        # things to always do after either success or exception\n        finally:\n            # display output dir path in terminal\n            log.info(f\"Output dir: {cfg.paths.output_dir}\")\n\n            # always close wandb run (even if exception occurs so multirun won't fail)\n            if find_spec(\"wandb\"):  # check if wandb is installed\n                import wandb\n\n                if wandb.run:\n                    log.info(\"Closing wandb!\")\n                    wandb.finish()\n\n        return metric_dict, object_dict\n\n    return wrap\n
    "}]} \ No newline at end of file +{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Home","text":"Machine Learning Project Template [![python](https://img.shields.io/badge/-Python_3.8_%7C_3.9_%7C_3.10-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit) [![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) [![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/) [![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/) [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) [![isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) Click on [Use this template](https://github.com/rayanramoul/ml-project-template/generate) to start your own project! or go to the [Documentation](https://rayanramoul.github.io/ml-project-template/) for more information. A template for machine learning or deep learning projects."},{"location":"#features","title":"\ud83e\udde0 Features","text":""},{"location":"#steps-for-installation","title":"\u2699\ufe0f Steps for Installation","text":""},{"location":"#tips-and-tricks","title":"\ud83e\udd20Tips and Tricks","text":""},{"location":"#how-does-the-project-work","title":"\ud83d\udc0d How does the project work?","text":"

    The train.py or eval.py script is the entry point of the project. It uses Hydra to instantiate the model (LightningModule), dataloader (DataModule), and trainer using the configuration reconstructed using Hydra. The model is then trained or evaluated using Pytorch Lightning.

    "},{"location":"#implementing-your-logic","title":"Implementing your logic","text":"

    You don't need to worry about implementing the training loops, the support for different hardwares, reading of configurations, etc. You need to care about 4 files for each training : your LightningModule (+ its hydra config), your DataModule (+ its hydra config).

    In the LightningModule, you need to implement the following methods:

    Get inspired by the provided examples in the src/data folder.

    Get to know more about Pytorch Lightning's LightningModule and DataModule in the Pytorch Lightning documentation. Finally in the associated configs/ folder, you need to implement the yaml configuration files for the model and dataloader.

    "},{"location":"#the-power-of-hydra","title":"\ud83d\udd0d The power of Hydra","text":"

    As Hydra is used for configuration, you can easily change the hyperparameters of your model, the dataloader, the trainer, etc. by changing the yaml configuration files in the configs/ folder. You can also use the --multirun option to run multiple experiments with different configurations.

    But also, as it used to instantiate the model and dataloader, you can easily change the model, dataloader, or any other component by changing the yaml configuration files or DIRECTLY IN COMMAND LINE. This is especially useful when you want to use different models or dataloaders.

    For example, you can run the following command to train a model with a different architecture, changing the dataset used, and the trainer used:

    uv run src/train.py model=LeNet datamodule=MNISTDataModule trainer=gpu\n

    Read more about Hydra in the official documentation.

    "},{"location":"#best-practices","title":"\ud83d\udca1 Best practices","text":""},{"location":"#documentation","title":"\ud83d\udcda Documentation","text":"

    You have the possibility to generate a documentation website using Mkdocs. It will automatically generate the documentation based on both the markdown files in the docs/ folder and the docstrings in your code. To generate and serve the documentation locally:

    make serve-docs # Documentation will be available at http://localhost:8000\n

    And to deploy it to Github pages (youn need to enable Pages in your repository configuration and set it to use the gh-pages branch):

    make pages-deploy # It will create a gh-pages branch and push the documentation to it\n
    "},{"location":"#tree-explained","title":"\ud83c\udf33 Tree Explained","text":"

    ``` . \u251c\u2500\u2500 commit-template.txt # use this file to set your commit message template, with make configure-commit template \u251c\u2500\u2500 configs # configuration files for hydra \u2502\u00a0\u00a0 \u251c\u2500\u2500 callbacks # configuration files for callbacks \u2502\u00a0\u00a0 \u251c\u2500\u2500 data # configuration files for datamodules \u2502\u00a0\u00a0 \u251c\u2500\u2500 debug # configuration files for pytorch lightning debuggers \u2502\u00a0\u00a0 \u251c\u2500\u2500 eval.yaml # configuration file for evaluation \u2502\u00a0\u00a0 \u251c\u2500\u2500 experiment # configuration files for experiments \u2502\u00a0\u00a0 \u251c\u2500\u2500 extras # configuration files for extra components \u2502\u00a0\u00a0 \u251c\u2500\u2500 hparams_search # configuration files for hyperparameters search \u2502\u00a0\u00a0 \u251c\u2500\u2500 local # configuration files for local training \u2502\u00a0\u00a0 \u251c\u2500\u2500 logger # configuration files for loggers (neptune, wandb, etc.) \u2502\u00a0\u00a0 \u251c\u2500\u2500 model # configuration files for models (LightningModule) \u2502\u00a0\u00a0 \u251c\u2500\u2500 paths # configuration files for paths \u2502\u00a0\u00a0 \u251c\u2500\u2500 trainer # configuration files for trainers (cpu, gpu, tpu) \u2502\u00a0\u00a0 \u2514\u2500\u2500 train.yaml # configuration file for training \u251c\u2500\u2500 data # data folder (to store potentially downloaded datasets) \u251c\u2500\u2500 Makefile # makefile contains useful commands for the project \u251c\u2500\u2500 notebooks # notebooks folder \u251c\u2500\u2500 pyproject.toml # pyproject.toml file for uv package manager \u251c\u2500\u2500 README.md # this file \u251c\u2500\u2500 ruff.toml # ruff.toml file for pre-commit \u251c\u2500\u2500 scripts # scripts folder \u2502\u00a0\u00a0 \u2514\u2500\u2500 example_train.sh \u251c\u2500\u2500 src # source code folder \u2502\u00a0\u00a0 \u251c\u2500\u2500 data # datamodules folder \u2502\u00a0\u00a0 \u2502\u00a0\u00a0 \u251c\u2500\u2500 components \u2502\u00a0\u00a0 \u2502\u00a0\u00a0 \u2514\u2500\u2500 mnist_datamodule.py \u2502\u00a0\u00a0 \u251c\u2500\u2500 eval.py # evaluation entry script \u2502\u00a0\u00a0 \u251c\u2500\u2500 models # models folder (LightningModule) \u2502\u00a0\u00a0 \u2502\u00a0\u00a0 \u251c\u2500\u2500 components # components folder, contains model parts or \"nets\" \u2502\u00a0\u00a0 \u251c\u2500\u2500 train.py # training entry script \u2502\u00a0\u00a0 \u2514\u2500\u2500 utils # utils folder \u2502\u00a0\u00a0 \u251c\u2500\u2500 instantiators.py # instantiators for models and dataloaders \u2502\u00a0\u00a0 \u251c\u2500\u2500 logging_utils.py # logger utils \u2502\u00a0\u00a0 \u251c\u2500\u2500 pylogger.py # multi-process and multi-gpu safe logging \u2502\u00a0\u00a0 \u251c\u2500\u2500 rich_utils.py # rich utils \u2502\u00a0\u00a0 \u2514\u2500\u2500 utils.py # general utils like multi-processing, etc. \u2514\u2500\u2500 tests # tests folder \u2514\u2500\u2500 conftest.py # fixtures for tests \u2514\u2500\u2500 mock_test.py # example of mocking tests

    ````

    "},{"location":"#contributing","title":"\ud83e\udd1d Contributing","text":"

    For more information on how to contribute to this project, please refer to the CONTRIBUTING.md file.

    "},{"location":"#aknowledgements","title":"\ud83c\udf1f Aknowledgements","text":"

    This template was heavily inspired by great existing ones, like:

    But with a few opininated changes and improvements, go check them out!

    "},{"location":"api/eval/","title":"Eval","text":"

    Main evaluation script.

    "},{"location":"api/eval/#src.eval.evaluate","title":"evaluate(cfg)","text":"

    Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during failure. Useful for multiruns, saving info about the crash, etc.

    Parameters:

    Name Type Description Default cfg DictConfig

    DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description tuple[dict[str, Any], dict[str, Any]]

    tuple[dict, dict] with metrics and dict with all instantiated objects.

    Source code in src/eval.py
    @task_wrapper\ndef evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:\n    \"\"\"Evaluates given checkpoint on a datamodule testset.\n\n    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during\n    failure. Useful for multiruns, saving info about the crash, etc.\n\n    Args:\n        cfg: DictConfig configuration composed by Hydra.\n\n    Returns:\n        tuple[dict, dict] with metrics and dict with all instantiated objects.\n    \"\"\"\n    assert cfg.ckpt_path\n\n    log.info(f\"Instantiating datamodule <{cfg.data._target_}>\")\n    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)\n\n    log.info(f\"Instantiating model <{cfg.model._target_}>\")\n    model: LightningModule = hydra.utils.instantiate(cfg.model)\n\n    if cfg.get(\"model_compile\", False):\n        log.info(\"Compiling model...\")\n        torch.compile(model)\n\n    log.info(\"Instantiating loggers...\")\n    logger: list[Logger] = instantiate_loggers(cfg.get(\"logger\"))\n\n    log.info(f\"Instantiating trainer <{cfg.trainer._target_}>\")\n    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)\n\n    object_dict = {\n        \"cfg\": cfg,\n        \"datamodule\": datamodule,\n        \"model\": model,\n        \"logger\": logger,\n        \"trainer\": trainer,\n    }\n\n    if logger:\n        log.info(\"Logging hyperparameters!\")\n        log_hyperparameters(object_dict)\n\n    log.info(\"Starting testing!\")\n    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)\n\n    # for predictions use trainer.predict(...)\n    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)\n\n    metric_dict = trainer.callback_metrics\n\n    return metric_dict, object_dict\n
    "},{"location":"api/eval/#src.eval.main","title":"main(cfg)","text":"

    Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.

    Source code in src/eval.py
    @hydra.main(version_base=\"1.3\", config_path=\"../configs\", config_name=\"eval.yaml\")\ndef main(cfg: DictConfig) -> None:\n    \"\"\"Main entry point for evaluation.\n\n    :param cfg: DictConfig configuration composed by Hydra.\n    \"\"\"\n    # apply extra utilities\n    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)\n    extras(cfg)\n\n    evaluate(cfg)\n
    "},{"location":"api/serve/","title":"Serve","text":"

    Main serve script.

    "},{"location":"api/serve/#src.serve.main","title":"main(cfg)","text":"

    Main entry point for serving.

    Parameters:

    Name Type Description Default cfg DictConfig

    DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description None

    Optional[float] with optimized metric value.

    Source code in src/serve.py
    @hydra.main(version_base=\"1.3\", config_path=\"../configs\", config_name=\"serve.yaml\")\ndef main(cfg: DictConfig) -> None:\n    \"\"\"Main entry point for serving.\n\n    Args:\n        cfg: DictConfig configuration composed by Hydra.\n\n    Returns:\n        Optional[float] with optimized metric value.\n    \"\"\"\n    serve(cfg)\n
    "},{"location":"api/serve/#src.serve.serve","title":"serve(cfg)","text":"

    Serve the specified model in the configuration as a FastAPI api.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description None

    A tuple with metrics and dict with all instantiated objects.

    Source code in src/serve.py
    @task_wrapper\ndef serve(cfg: DictConfig) -> None:\n    \"\"\"Serve the specified model in the configuration as a FastAPI api.\n\n    Args:\n        cfg: A DictConfig configuration composed by Hydra.\n\n    Returns:\n        A tuple with metrics and dict with all instantiated objects.\n    \"\"\"\n    # set seed for random number generators in pytorch, numpy and python.random\n    if cfg.get(\"seed\"):\n        lightning.seed_everything(cfg.seed, workers=True)\n    log.info(f\"Getting model class <{cfg.model._target_}>\")\n    model_class = hydra.utils.get_class(cfg.model._target_)\n    lit_server_api = hydra.utils.instantiate(cfg.serve.api, model_class=model_class)\n    # Create the LitServe server with the MNISTServeAPI\n    server = ls.LitServer(lit_server_api, accelerator=cfg.serve.accelerator, max_batch_size=cfg.serve.max_batch_size)\n    log.info(\"Initialized LitServe server\")\n    # Run the server on port 8000\n    log.info(f\"Starting LitServe server on port {cfg.serve.port}\")\n    server.run(port=cfg.serve.port)\n
    "},{"location":"api/train/","title":"Train","text":"

    Main training script.

    "},{"location":"api/train/#src.train.main","title":"main(cfg)","text":"

    Main entry point for training.

    Parameters:

    Name Type Description Default cfg DictConfig

    DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description float | None

    Optional[float] with optimized metric value.

    Source code in src/train.py
    @hydra.main(version_base=\"1.3\", config_path=\"../configs\", config_name=\"train.yaml\")\ndef main(cfg: DictConfig) -> float | None:\n    \"\"\"Main entry point for training.\n\n    Args:\n        cfg: DictConfig configuration composed by Hydra.\n\n    Returns:\n        Optional[float] with optimized metric value.\n    \"\"\"\n    # apply extra utilities\n    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)\n    extras(cfg)\n\n    # train the model\n    metric_dict, _ = train(cfg)\n\n    # safely retrieve metric value for hydra-based hyperparameter optimization\n    metric_value = get_metric_value(metric_dict=metric_dict, metric_name=cfg.get(\"optimized_metric\"))\n\n    # return optimized metric\n    return metric_value\n
    "},{"location":"api/train/#src.train.train","title":"train(cfg)","text":"

    Trains the model. Can additionally evaluate on a testset, using best weights obtained during training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during failure. Useful for multiruns, saving info about the crash, etc.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig configuration composed by Hydra.

    required

    Returns:

    Type Description tuple[dict[str, Any], dict[str, Any]]

    A tuple with metrics and dict with all instantiated objects.

    Source code in src/train.py
    @task_wrapper\ndef train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:\n    \"\"\"Trains the model. Can additionally evaluate on a testset, using best weights obtained during training.\n\n    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during\n    failure. Useful for multiruns, saving info about the crash, etc.\n\n    Args:\n        cfg: A DictConfig configuration composed by Hydra.\n\n    Returns:\n        A tuple with metrics and dict with all instantiated objects.\n    \"\"\"\n    # set seed for random number generators in pytorch, numpy and python.random\n    if cfg.get(\"seed\"):\n        lightning.seed_everything(cfg.seed, workers=True)\n\n    log.info(f\"Instantiating datamodule <{cfg.data._target_}>\")\n    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)\n\n    log.info(f\"Instantiating model <{cfg.model._target_}>\")\n    model: LightningModule = hydra.utils.instantiate(cfg.model)\n\n    if cfg.get(\"model_compile\", False):\n        log.info(\"Compiling model...\")\n        torch.compile(model)\n\n    log.info(\"Instantiating callbacks...\")\n    callbacks: list[Callback] = instantiate_callbacks(cfg.get(\"callbacks\"))\n\n    log.info(\"Instantiating loggers...\")\n    logger: list[Logger] = instantiate_loggers(cfg.get(\"logger\"))\n\n    log.info(f\"Instantiating trainer <{cfg.trainer._target_}>\")\n    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)\n\n    object_dict = {\n        \"cfg\": cfg,\n        \"datamodule\": datamodule,\n        \"model\": model,\n        \"callbacks\": callbacks,\n        \"logger\": logger,\n        \"trainer\": trainer,\n    }\n\n    if logger:\n        log.info(\"Logging hyperparameters!\")\n        log_hyperparameters(object_dict)\n\n    if cfg.get(\"train\"):\n        log.info(\"Starting training!\")\n        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get(\"ckpt_path\"))\n\n    train_metrics = trainer.callback_metrics\n\n    if cfg.get(\"test\"):\n        log.info(\"Starting testing!\")\n        ckpt_path = trainer.checkpoint_callback.best_model_path\n        if ckpt_path == \"\":\n            log.warning(\"Best ckpt not found! Using current weights for testing...\")\n            ckpt_path = None\n        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)\n        log.info(f\"Best ckpt path: {ckpt_path}\")\n\n    test_metrics = trainer.callback_metrics\n\n    # merge train and test metrics\n    metric_dict = {**train_metrics, **test_metrics}\n\n    return metric_dict, object_dict\n
    "},{"location":"api/data/mnist_datamodule/","title":"Mnist datamodule","text":"

    MNIST DataModule.

    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule","title":"MNISTDataModule","text":"

    Bases: LightningDataModule

    LightningDataModule for the MNIST dataset.

    The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.

    A LightningDataModule implements 7 key methods:

        def prepare_data(self):\n    # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).\n    # Download data, pre-process, split, save to disk, etc...\n\n    def setup(self, stage):\n    # Things to do on every process in DDP.\n    # Load data, set variables, etc...\n\n    def train_dataloader(self):\n    # return train dataloader\n\n    def val_dataloader(self):\n    # return validation dataloader\n\n    def test_dataloader(self):\n    # return test dataloader\n\n    def predict_dataloader(self):\n    # return predict dataloader\n\n    def teardown(self, stage):\n    # Called on every process in DDP.\n    # Clean up after fit or test.\n

    This allows you to share a full dataset without explaining how to download, split, transform and process the data.

    Read the docs

    https://lightning.ai/docs/pytorch/latest/data/datamodule.html

    Source code in src/data/mnist_datamodule.py
    class MNISTDataModule(LightningDataModule):\n    \"\"\"`LightningDataModule` for the MNIST dataset.\n\n    The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.\n    It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a\n    fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box\n    while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing\n    technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of\n    mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.\n\n    A `LightningDataModule` implements 7 key methods:\n\n    ```python\n        def prepare_data(self):\n        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).\n        # Download data, pre-process, split, save to disk, etc...\n\n        def setup(self, stage):\n        # Things to do on every process in DDP.\n        # Load data, set variables, etc...\n\n        def train_dataloader(self):\n        # return train dataloader\n\n        def val_dataloader(self):\n        # return validation dataloader\n\n        def test_dataloader(self):\n        # return test dataloader\n\n        def predict_dataloader(self):\n        # return predict dataloader\n\n        def teardown(self, stage):\n        # Called on every process in DDP.\n        # Clean up after fit or test.\n    ```\n\n    This allows you to share a full dataset without explaining how to download,\n    split, transform and process the data.\n\n    Read the docs:\n        https://lightning.ai/docs/pytorch/latest/data/datamodule.html\n    \"\"\"\n\n    def __init__(\n        self,\n        data_dir: str = \"data/\",\n        train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000),\n        batch_size: int = 64,\n        num_workers: int = 0,\n        pin_memory: bool = False,\n    ) -> None:\n        \"\"\"Initialize a `MNISTDataModule`.\n\n        Args:\n            data_dir: The data directory. Defaults to `\"data/\"`.\n            train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.\n            batch_size: The batch size. Defaults to `64`.\n            num_workers: The number of workers. Defaults to `0`.\n            pin_memory: Whether to pin memory. Defaults to `False`.\n        \"\"\"\n        super().__init__()\n\n        # this line allows to access init params with 'self.hparams' attribute\n        # also ensures init params will be stored in ckpt\n        self.save_hyperparameters(logger=False)\n\n        # data transformations\n        self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n\n        self.data_train: Dataset | None = None\n        self.data_val: Dataset | None = None\n        self.data_test: Dataset | None = None\n\n        self.batch_size_per_device = batch_size\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"Get the number of classes.\n\n        :return: The number of MNIST classes (10).\n        \"\"\"\n        return 10\n\n    def prepare_data(self) -> None:\n        \"\"\"Download data if needed.\n\n        Lightning ensures that `self.prepare_data()` is called only\n        within a single process on CPU, so you can safely add your downloading logic within. In\n        case of multi-node training, the execution of this hook depends upon\n        `self.prepare_data_per_node()`.\n\n        Do not use it to assign state (self.x = y).\n        \"\"\"\n        MNIST(self.hparams.data_dir, train=True, download=True)\n        MNIST(self.hparams.data_dir, train=False, download=True)\n\n    def setup(self, stage: str | None = None) -> None:\n        \"\"\"Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.\n\n        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and\n        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after\n        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to\n        `self.setup()` once the data is prepared and available for use.\n\n        Args:\n            stage: The stage to setup. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`. Defaults to ``None``.\n        \"\"\"\n        # Divide batch size by the number of devices.\n        if self.trainer is not None:\n            if self.hparams.batch_size % self.trainer.world_size != 0:\n                raise RuntimeError(  # noqa\n                    f\"Batch size ({self.hparams.batch_size}) \"\n                    \"is not divisible by the number of devices ({self.trainer.world_size}).\"\n                )\n            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size\n\n        # load and split datasets only if not loaded already\n        if not self.data_train and not self.data_val and not self.data_test:\n            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)\n            testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)\n            dataset = ConcatDataset(datasets=[trainset, testset])\n            self.data_train, self.data_val, self.data_test = random_split(\n                dataset=dataset,\n                lengths=self.hparams.train_val_test_split,\n                generator=torch.Generator().manual_seed(42),\n            )\n\n    def train_dataloader(self) -> DataLoader[Any]:\n        \"\"\"Create and return the train dataloader.\n\n        Returns:\n            The train dataloader.\n        \"\"\"\n        return DataLoader(\n            dataset=self.data_train,\n            batch_size=self.batch_size_per_device,\n            num_workers=self.hparams.num_workers,\n            pin_memory=self.hparams.pin_memory,\n            shuffle=True,\n        )\n\n    def val_dataloader(self) -> DataLoader[Any]:\n        \"\"\"Create and return the validation dataloader.\n\n        Returns:\n            The validation dataloader.\n        \"\"\"\n        return DataLoader(\n            dataset=self.data_val,\n            batch_size=self.batch_size_per_device,\n            num_workers=self.hparams.num_workers,\n            pin_memory=self.hparams.pin_memory,\n            shuffle=False,\n        )\n\n    def test_dataloader(self) -> DataLoader[Any]:\n        \"\"\"Create and return the test dataloader.\n\n        Returns:\n            The test dataloader.\n        \"\"\"\n        return DataLoader(\n            dataset=self.data_test,\n            batch_size=self.batch_size_per_device,\n            num_workers=self.hparams.num_workers,\n            pin_memory=self.hparams.pin_memory,\n            shuffle=False,\n        )\n\n    def teardown(self, stage: str | None = None) -> None:\n        \"\"\"Lightning hook for cleaning up after trainer main functions.\n\n        `trainer.fit()`, `trainer.validate()`,`trainer.test()`, and `trainer.predict()`.\n\n        Args:\n            stage: The stage being torn down. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n            Defaults to ``None``.\n        \"\"\"\n        pass\n\n    def state_dict(self) -> dict[Any, Any]:\n        \"\"\"Called when saving a checkpoint. Implement to generate and save the datamodule state.\n\n        Returns:\n            A dictionary containing the datamodule state that you want to save.\n        \"\"\"\n        return {}\n\n    def load_state_dict(self, state_dict: dict[str, Any]) -> None:\n        \"\"\"Called when loading a checkpoint. Implement to reload datamodule state given datamodule `state_dict()`.\n\n        Args:\n            state_dict: The datamodule state returned by `self.state_dict()`.\n        \"\"\"\n        pass\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.num_classes","title":"num_classes: int property","text":"

    Get the number of classes.

    :return: The number of MNIST classes (10).

    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.__init__","title":"__init__(data_dir='data/', train_val_test_split=(55000, 5000, 10000), batch_size=64, num_workers=0, pin_memory=False)","text":"

    Initialize a MNISTDataModule.

    Parameters:

    Name Type Description Default data_dir str

    The data directory. Defaults to \"data/\".

    'data/' train_val_test_split tuple[int, int, int]

    The train, validation and test split. Defaults to (55_000, 5_000, 10_000).

    (55000, 5000, 10000) batch_size int

    The batch size. Defaults to 64.

    64 num_workers int

    The number of workers. Defaults to 0.

    0 pin_memory bool

    Whether to pin memory. Defaults to False.

    False Source code in src/data/mnist_datamodule.py
    def __init__(\n    self,\n    data_dir: str = \"data/\",\n    train_val_test_split: tuple[int, int, int] = (55_000, 5_000, 10_000),\n    batch_size: int = 64,\n    num_workers: int = 0,\n    pin_memory: bool = False,\n) -> None:\n    \"\"\"Initialize a `MNISTDataModule`.\n\n    Args:\n        data_dir: The data directory. Defaults to `\"data/\"`.\n        train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.\n        batch_size: The batch size. Defaults to `64`.\n        num_workers: The number of workers. Defaults to `0`.\n        pin_memory: Whether to pin memory. Defaults to `False`.\n    \"\"\"\n    super().__init__()\n\n    # this line allows to access init params with 'self.hparams' attribute\n    # also ensures init params will be stored in ckpt\n    self.save_hyperparameters(logger=False)\n\n    # data transformations\n    self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n\n    self.data_train: Dataset | None = None\n    self.data_val: Dataset | None = None\n    self.data_test: Dataset | None = None\n\n    self.batch_size_per_device = batch_size\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.load_state_dict","title":"load_state_dict(state_dict)","text":"

    Called when loading a checkpoint. Implement to reload datamodule state given datamodule state_dict().

    Parameters:

    Name Type Description Default state_dict dict[str, Any]

    The datamodule state returned by self.state_dict().

    required Source code in src/data/mnist_datamodule.py
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:\n    \"\"\"Called when loading a checkpoint. Implement to reload datamodule state given datamodule `state_dict()`.\n\n    Args:\n        state_dict: The datamodule state returned by `self.state_dict()`.\n    \"\"\"\n    pass\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.prepare_data","title":"prepare_data()","text":"

    Download data if needed.

    Lightning ensures that self.prepare_data() is called only within a single process on CPU, so you can safely add your downloading logic within. In case of multi-node training, the execution of this hook depends upon self.prepare_data_per_node().

    Do not use it to assign state (self.x = y).

    Source code in src/data/mnist_datamodule.py
    def prepare_data(self) -> None:\n    \"\"\"Download data if needed.\n\n    Lightning ensures that `self.prepare_data()` is called only\n    within a single process on CPU, so you can safely add your downloading logic within. In\n    case of multi-node training, the execution of this hook depends upon\n    `self.prepare_data_per_node()`.\n\n    Do not use it to assign state (self.x = y).\n    \"\"\"\n    MNIST(self.hparams.data_dir, train=True, download=True)\n    MNIST(self.hparams.data_dir, train=False, download=True)\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.setup","title":"setup(stage=None)","text":"

    Load data. Set variables: self.data_train, self.data_val, self.data_test.

    This method is called by Lightning before trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict(), so be careful not to execute things like random split twice! Also, it is called after self.prepare_data() and there is a barrier in between which ensures that all the processes proceed to self.setup() once the data is prepared and available for use.

    Parameters:

    Name Type Description Default stage str | None

    The stage to setup. Either \"fit\", \"validate\", \"test\", or \"predict\". Defaults to None.

    None Source code in src/data/mnist_datamodule.py
    def setup(self, stage: str | None = None) -> None:\n    \"\"\"Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.\n\n    This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and\n    `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after\n    `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to\n    `self.setup()` once the data is prepared and available for use.\n\n    Args:\n        stage: The stage to setup. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`. Defaults to ``None``.\n    \"\"\"\n    # Divide batch size by the number of devices.\n    if self.trainer is not None:\n        if self.hparams.batch_size % self.trainer.world_size != 0:\n            raise RuntimeError(  # noqa\n                f\"Batch size ({self.hparams.batch_size}) \"\n                \"is not divisible by the number of devices ({self.trainer.world_size}).\"\n            )\n        self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size\n\n    # load and split datasets only if not loaded already\n    if not self.data_train and not self.data_val and not self.data_test:\n        trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)\n        testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)\n        dataset = ConcatDataset(datasets=[trainset, testset])\n        self.data_train, self.data_val, self.data_test = random_split(\n            dataset=dataset,\n            lengths=self.hparams.train_val_test_split,\n            generator=torch.Generator().manual_seed(42),\n        )\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.state_dict","title":"state_dict()","text":"

    Called when saving a checkpoint. Implement to generate and save the datamodule state.

    Returns:

    Type Description dict[Any, Any]

    A dictionary containing the datamodule state that you want to save.

    Source code in src/data/mnist_datamodule.py
    def state_dict(self) -> dict[Any, Any]:\n    \"\"\"Called when saving a checkpoint. Implement to generate and save the datamodule state.\n\n    Returns:\n        A dictionary containing the datamodule state that you want to save.\n    \"\"\"\n    return {}\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.teardown","title":"teardown(stage=None)","text":"

    Lightning hook for cleaning up after trainer main functions.

    trainer.fit(), trainer.validate(),trainer.test(), and trainer.predict().

    Parameters:

    Name Type Description Default stage str | None

    The stage being torn down. Either \"fit\", \"validate\", \"test\", or \"predict\".

    None Source code in src/data/mnist_datamodule.py
    def teardown(self, stage: str | None = None) -> None:\n    \"\"\"Lightning hook for cleaning up after trainer main functions.\n\n    `trainer.fit()`, `trainer.validate()`,`trainer.test()`, and `trainer.predict()`.\n\n    Args:\n        stage: The stage being torn down. Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n        Defaults to ``None``.\n    \"\"\"\n    pass\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.test_dataloader","title":"test_dataloader()","text":"

    Create and return the test dataloader.

    Returns:

    Type Description DataLoader[Any]

    The test dataloader.

    Source code in src/data/mnist_datamodule.py
    def test_dataloader(self) -> DataLoader[Any]:\n    \"\"\"Create and return the test dataloader.\n\n    Returns:\n        The test dataloader.\n    \"\"\"\n    return DataLoader(\n        dataset=self.data_test,\n        batch_size=self.batch_size_per_device,\n        num_workers=self.hparams.num_workers,\n        pin_memory=self.hparams.pin_memory,\n        shuffle=False,\n    )\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.train_dataloader","title":"train_dataloader()","text":"

    Create and return the train dataloader.

    Returns:

    Type Description DataLoader[Any]

    The train dataloader.

    Source code in src/data/mnist_datamodule.py
    def train_dataloader(self) -> DataLoader[Any]:\n    \"\"\"Create and return the train dataloader.\n\n    Returns:\n        The train dataloader.\n    \"\"\"\n    return DataLoader(\n        dataset=self.data_train,\n        batch_size=self.batch_size_per_device,\n        num_workers=self.hparams.num_workers,\n        pin_memory=self.hparams.pin_memory,\n        shuffle=True,\n    )\n
    "},{"location":"api/data/mnist_datamodule/#src.data.mnist_datamodule.MNISTDataModule.val_dataloader","title":"val_dataloader()","text":"

    Create and return the validation dataloader.

    Returns:

    Type Description DataLoader[Any]

    The validation dataloader.

    Source code in src/data/mnist_datamodule.py
    def val_dataloader(self) -> DataLoader[Any]:\n    \"\"\"Create and return the validation dataloader.\n\n    Returns:\n        The validation dataloader.\n    \"\"\"\n    return DataLoader(\n        dataset=self.data_val,\n        batch_size=self.batch_size_per_device,\n        num_workers=self.hparams.num_workers,\n        pin_memory=self.hparams.pin_memory,\n        shuffle=False,\n    )\n
    "},{"location":"api/data/polars_datamodule/","title":"Polars datamodule","text":"

    PyTorch Lightning DataModule for loading dataset using Polars.

    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule","title":"PolarsDataModule","text":"

    Bases: LightningDataModule

    PyTorch Lightning DataModule for loading dataset using Polars.

    Source code in src/data/polars_datamodule.py
    class PolarsDataModule(LightningDataModule):\n    \"\"\"PyTorch Lightning DataModule for loading dataset using Polars.\"\"\"\n\n    def __init__(\n        self, data_path: str, output_column: str, batch_size: int = 32, num_workers: int = 0, test_size: float = 0.2\n    ) -> None:\n        \"\"\"Initialize the PolarsDataModule.\n\n        Args:\n            data_path: Path to the dataset.\n            output_column: Column name that contains the labels.\n            batch_size: Batch size for the dataloaders.\n            num_workers: Number of workers for the dataloaders.\n            test_size: Fraction of the dataset to be used for validation.\n        \"\"\"\n        super().__init__()\n        self.data_path = data_path\n        self.output_column = output_column\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n        self.test_size = test_size\n        self.df = None  # Will hold the loaded Polars DataFrame\n\n    def setup(self, stage: str = \"\") -> None:\n        \"\"\"Load and split the dataset into train and validation sets.\"\"\"\n        # Load dataset using Polars\n        self.df = pl.read_csv(self.data_path)\n\n        # Split the data into train and validation sets\n        train_df, val_df = train_test_split(self.df, test_size=self.test_size, random_state=42)\n\n        self.train_dataset = PolarsDataset(pl.DataFrame(train_df), output_column=self.output_column)\n        self.val_dataset = PolarsDataset(pl.DataFrame(val_df), output_column=self.output_column)\n\n    def train_dataloader(self) -> DataLoader:\n        \"\"\"Create and return the train dataloader.\"\"\"\n        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)\n\n    def val_dataloader(self) -> DataLoader:\n        \"\"\"Create and return the validation dataloader.\"\"\"\n        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.__init__","title":"__init__(data_path, output_column, batch_size=32, num_workers=0, test_size=0.2)","text":"

    Initialize the PolarsDataModule.

    Parameters:

    Name Type Description Default data_path str

    Path to the dataset.

    required output_column str

    Column name that contains the labels.

    required batch_size int

    Batch size for the dataloaders.

    32 num_workers int

    Number of workers for the dataloaders.

    0 test_size float

    Fraction of the dataset to be used for validation.

    0.2 Source code in src/data/polars_datamodule.py
    def __init__(\n    self, data_path: str, output_column: str, batch_size: int = 32, num_workers: int = 0, test_size: float = 0.2\n) -> None:\n    \"\"\"Initialize the PolarsDataModule.\n\n    Args:\n        data_path: Path to the dataset.\n        output_column: Column name that contains the labels.\n        batch_size: Batch size for the dataloaders.\n        num_workers: Number of workers for the dataloaders.\n        test_size: Fraction of the dataset to be used for validation.\n    \"\"\"\n    super().__init__()\n    self.data_path = data_path\n    self.output_column = output_column\n    self.batch_size = batch_size\n    self.num_workers = num_workers\n    self.test_size = test_size\n    self.df = None  # Will hold the loaded Polars DataFrame\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.setup","title":"setup(stage='')","text":"

    Load and split the dataset into train and validation sets.

    Source code in src/data/polars_datamodule.py
    def setup(self, stage: str = \"\") -> None:\n    \"\"\"Load and split the dataset into train and validation sets.\"\"\"\n    # Load dataset using Polars\n    self.df = pl.read_csv(self.data_path)\n\n    # Split the data into train and validation sets\n    train_df, val_df = train_test_split(self.df, test_size=self.test_size, random_state=42)\n\n    self.train_dataset = PolarsDataset(pl.DataFrame(train_df), output_column=self.output_column)\n    self.val_dataset = PolarsDataset(pl.DataFrame(val_df), output_column=self.output_column)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.train_dataloader","title":"train_dataloader()","text":"

    Create and return the train dataloader.

    Source code in src/data/polars_datamodule.py
    def train_dataloader(self) -> DataLoader:\n    \"\"\"Create and return the train dataloader.\"\"\"\n    return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataModule.val_dataloader","title":"val_dataloader()","text":"

    Create and return the validation dataloader.

    Source code in src/data/polars_datamodule.py
    def val_dataloader(self) -> DataLoader:\n    \"\"\"Create and return the validation dataloader.\"\"\"\n    return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset","title":"PolarsDataset","text":"

    Bases: Dataset

    Custom PyTorch Dataset wrapping a Polars DataFrame.

    Source code in src/data/polars_datamodule.py
    class PolarsDataset(Dataset):\n    \"\"\"Custom PyTorch Dataset wrapping a Polars DataFrame.\"\"\"\n\n    def __init__(self, df: pl.DataFrame, output_column: str) -> None:\n        \"\"\"Initialize the PolarsDataset.\"\"\"\n        self.df = df\n        self.output_column = output_column\n\n    def __len__(self) -> int:\n        \"\"\"Return the number of rows in the dataset.\"\"\"\n        return self.df.shape[0]\n\n    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Return the features and label for the given index.\"\"\"\n        row = self.df[idx]\n        features = torch.tensor([val for col, val in row.items() if col != self.output_column], dtype=torch.float32)\n        label = torch.tensor(row[self.output_column], dtype=torch.long)\n        return features, label\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset.__getitem__","title":"__getitem__(idx)","text":"

    Return the features and label for the given index.

    Source code in src/data/polars_datamodule.py
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Return the features and label for the given index.\"\"\"\n    row = self.df[idx]\n    features = torch.tensor([val for col, val in row.items() if col != self.output_column], dtype=torch.float32)\n    label = torch.tensor(row[self.output_column], dtype=torch.long)\n    return features, label\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset.__init__","title":"__init__(df, output_column)","text":"

    Initialize the PolarsDataset.

    Source code in src/data/polars_datamodule.py
    def __init__(self, df: pl.DataFrame, output_column: str) -> None:\n    \"\"\"Initialize the PolarsDataset.\"\"\"\n    self.df = df\n    self.output_column = output_column\n
    "},{"location":"api/data/polars_datamodule/#src.data.polars_datamodule.PolarsDataset.__len__","title":"__len__()","text":"

    Return the number of rows in the dataset.

    Source code in src/data/polars_datamodule.py
    def __len__(self) -> int:\n    \"\"\"Return the number of rows in the dataset.\"\"\"\n    return self.df.shape[0]\n
    "},{"location":"api/models/mnist_module/","title":"Mnist module","text":"

    Mnist simple model.

    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule","title":"MNISTLitModule","text":"

    Bases: LightningModule

    Example of a LightningModule for MNIST classification.

    A LightningModule implements 8 key methods:

    def __init__(self):\n# Define initialization code here.\n\ndef setup(self, stage):\n# Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.\n# This hook is called on every process when using DDP.\n\ndef training_step(self, batch, batch_idx):\n# The complete training step.\n\ndef validation_step(self, batch, batch_idx):\n# The complete validation step.\n\ndef test_step(self, batch, batch_idx):\n# The complete test step.\n\ndef predict_step(self, batch, batch_idx):\n# The complete predict step.\n\ndef configure_optimizers(self):\n# Define and configure optimizers and LR schedulers.\n
    Docs

    https://lightning.ai/docs/pytorch/latest/common/lightning_module.html

    Source code in src/models/mnist_module.py
    class MNISTLitModule(LightningModule):\n    \"\"\"Example of a `LightningModule` for MNIST classification.\n\n    A `LightningModule` implements 8 key methods:\n\n    ```python\n    def __init__(self):\n    # Define initialization code here.\n\n    def setup(self, stage):\n    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.\n    # This hook is called on every process when using DDP.\n\n    def training_step(self, batch, batch_idx):\n    # The complete training step.\n\n    def validation_step(self, batch, batch_idx):\n    # The complete validation step.\n\n    def test_step(self, batch, batch_idx):\n    # The complete test step.\n\n    def predict_step(self, batch, batch_idx):\n    # The complete predict step.\n\n    def configure_optimizers(self):\n    # Define and configure optimizers and LR schedulers.\n    ```\n\n    Docs:\n        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html\n    \"\"\"\n\n    def __init__(\n        self,\n        net: torch.nn.Module,\n        optimizer: torch.optim.Optimizer,\n        scheduler: torch.optim.lr_scheduler,\n        compile_model: bool,\n    ) -> None:\n        \"\"\"Initialize a `MNISTLitModule`.\n\n        Args:\n            net: The model to train.\n            optimizer: The optimizer to use for training.\n            scheduler: The learning rate scheduler to use for training.\n            compile_model: Whether or not compile the model.\n        \"\"\"\n        super().__init__()\n\n        # this line allows to access init params with 'self.hparams' attribute\n        # also ensures init params will be stored in ckpt\n        self.save_hyperparameters(logger=False)\n\n        self.net = net\n\n        # loss function\n        self.criterion = torch.nn.CrossEntropyLoss()\n\n        # metric objects for calculating and averaging accuracy across batches\n        self.train_acc = Accuracy(task=\"multiclass\", num_classes=10)\n        self.val_acc = Accuracy(task=\"multiclass\", num_classes=10)\n        self.test_acc = Accuracy(task=\"multiclass\", num_classes=10)\n\n        # for averaging loss across batches\n        self.train_loss = MeanMetric()\n        self.val_loss = MeanMetric()\n        self.test_loss = MeanMetric()\n\n        # for tracking best so far validation accuracy\n        self.val_acc_best = MaxMetric()\n\n    @typechecked\n    def forward(self, x: TensorType[\"batch\", 1, 28, 28]) -> TensorType[\"batch\", 10]:  # noqa\n        \"\"\"Perform a forward pass through the model.\n\n        Args:\n            x: A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.\n\n        Returns:\n            A tensor of shape (batch_size, 10) representing the logits for each class.\n        \"\"\"\n        return self.net(x)\n\n    def on_train_start(self) -> None:\n        \"\"\"Lightning hook that is called when training begins.\"\"\"\n        # by default lightning executes validation step sanity checks before training starts,\n        # so it's worth to make sure validation metrics don't store results from these checks\n        self.val_loss.reset()\n        self.val_acc.reset()\n        self.val_acc_best.reset()\n\n    @typechecked\n    def model_step(self, x: TensorType[\"batch\", 1, 28, 28], y: TensorType[\"batch\"]):  # noqa\n        \"\"\"Perform a single model step.\n\n        Args:\n            x: Tensor of shape [batch, 1, 28, 28] representing the images.\n            y: Tensor of shape [batch] representing the classes.\n\n        Returns:\n            A tuple containing:\n                - loss: A tensor of shape (batch_size,)\n                - preds: A tensor of predicted class indices (batch_size,)\n                - targets: A tensor of true class labels (batch_size,)\n        \"\"\"\n        logits = self.forward(x)\n        loss = self.criterion(logits, y)\n        preds = torch.argmax(logits, dim=1)\n        return loss, preds, y\n\n    @typechecked\n    def training_step(self, batch: Any) -> TensorType[()]:\n        \"\"\"Perform a single training step.\n\n        Args:\n            batch: A tuple containing input images and target labels.\n            batch_idx: The index of the current batch.\n\n        Returns:\n            A scalar loss tensor.\n        \"\"\"\n        x, y = batch\n        loss, preds, targets = self.model_step(x, y)\n        self.train_loss(loss)\n        self.train_acc(preds, targets)\n        self.log(\"train/loss\", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)\n        self.log(\"train/acc\", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)\n        return loss\n\n    def on_train_epoch_end(self) -> None:\n        \"\"\"Lightning hook that is called when a training epoch ends.\"\"\"\n        pass\n\n    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n        \"\"\"Perform a single validation step on a batch of data from the validation set.\n\n        Args:\n            batch: A batch of data (a tuple) containing the input tensor of images and target\n                labels.\n            batch_idx: The index of the current batch.\n        \"\"\"\n        x, y = batch\n        loss, preds, targets = self.model_step(x, y)\n\n        # update and log metrics\n        self.val_loss(loss)\n        self.val_acc(preds, targets)\n        self.log(\"val/loss\", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)\n        self.log(\"val/acc\", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)\n\n    def on_validation_epoch_end(self) -> None:\n        \"\"\"Lightning hook that is called when a validation epoch ends.\"\"\"\n        acc = self.val_acc.compute()  # get current val acc\n        self.val_acc_best(acc)  # update best so far val acc\n        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object\n        # otherwise metric would be reset by lightning after each epoch\n        self.log(\"val/acc_best\", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)\n\n    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n        \"\"\"Perform a single test step on a batch of data from the test set.\n\n        Args:\n            batch: A batch of data (a tuple) containing the input tensor of images and target\n                labels.\n            batch_idx: The index of the current batch.\n        \"\"\"\n        x, y = batch\n        loss, preds, targets = self.model_step(x, y)\n\n        # update and log metrics\n        self.test_loss(loss)\n        self.test_acc(preds, targets)\n        self.log(\"test/loss\", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)\n        self.log(\"test/acc\", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)\n\n    def on_test_epoch_end(self) -> None:\n        \"\"\"Lightning hook that is called when a test epoch ends.\"\"\"\n        pass\n\n    def setup(self, stage: str) -> None:\n        \"\"\"Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.\n\n        This is a good hook when you need to build models dynamically or adjust something about\n        them. This hook is called on every process when using DDP.\n\n        Args:\n            stage: Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n        \"\"\"\n        if self.hparams.compile_model and stage == \"fit\":\n            self.net = torch.compile(self.net)\n\n    def configure_optimizers(self) -> dict[str, Any]:\n        \"\"\"Choose what optimizers and learning-rate schedulers to use in your optimization.\n\n        Normally you'd need one. But in the case of GANs or similar you might have multiple.\n\n        Examples:\n            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers\n\n        Returns:\n            A dict containing the configured optimizers and learning-rate schedulers to be used for training.\n        \"\"\"\n        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())\n        if self.hparams.scheduler is not None:\n            scheduler = self.hparams.scheduler(optimizer=optimizer)\n            return {\n                \"optimizer\": optimizer,\n                \"lr_scheduler\": {\n                    \"scheduler\": scheduler,\n                    \"monitor\": \"val/loss\",\n                    \"interval\": \"epoch\",\n                    \"frequency\": 1,\n                },\n            }\n        return {\"optimizer\": optimizer}\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.__init__","title":"__init__(net, optimizer, scheduler, compile_model)","text":"

    Initialize a MNISTLitModule.

    Parameters:

    Name Type Description Default net Module

    The model to train.

    required optimizer Optimizer

    The optimizer to use for training.

    required scheduler lr_scheduler

    The learning rate scheduler to use for training.

    required compile_model bool

    Whether or not compile the model.

    required Source code in src/models/mnist_module.py
    def __init__(\n    self,\n    net: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    scheduler: torch.optim.lr_scheduler,\n    compile_model: bool,\n) -> None:\n    \"\"\"Initialize a `MNISTLitModule`.\n\n    Args:\n        net: The model to train.\n        optimizer: The optimizer to use for training.\n        scheduler: The learning rate scheduler to use for training.\n        compile_model: Whether or not compile the model.\n    \"\"\"\n    super().__init__()\n\n    # this line allows to access init params with 'self.hparams' attribute\n    # also ensures init params will be stored in ckpt\n    self.save_hyperparameters(logger=False)\n\n    self.net = net\n\n    # loss function\n    self.criterion = torch.nn.CrossEntropyLoss()\n\n    # metric objects for calculating and averaging accuracy across batches\n    self.train_acc = Accuracy(task=\"multiclass\", num_classes=10)\n    self.val_acc = Accuracy(task=\"multiclass\", num_classes=10)\n    self.test_acc = Accuracy(task=\"multiclass\", num_classes=10)\n\n    # for averaging loss across batches\n    self.train_loss = MeanMetric()\n    self.val_loss = MeanMetric()\n    self.test_loss = MeanMetric()\n\n    # for tracking best so far validation accuracy\n    self.val_acc_best = MaxMetric()\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.configure_optimizers","title":"configure_optimizers()","text":"

    Choose what optimizers and learning-rate schedulers to use in your optimization.

    Normally you'd need one. But in the case of GANs or similar you might have multiple.

    Examples:

    https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

    Returns:

    Type Description dict[str, Any]

    A dict containing the configured optimizers and learning-rate schedulers to be used for training.

    Source code in src/models/mnist_module.py
    def configure_optimizers(self) -> dict[str, Any]:\n    \"\"\"Choose what optimizers and learning-rate schedulers to use in your optimization.\n\n    Normally you'd need one. But in the case of GANs or similar you might have multiple.\n\n    Examples:\n        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers\n\n    Returns:\n        A dict containing the configured optimizers and learning-rate schedulers to be used for training.\n    \"\"\"\n    optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())\n    if self.hparams.scheduler is not None:\n        scheduler = self.hparams.scheduler(optimizer=optimizer)\n        return {\n            \"optimizer\": optimizer,\n            \"lr_scheduler\": {\n                \"scheduler\": scheduler,\n                \"monitor\": \"val/loss\",\n                \"interval\": \"epoch\",\n                \"frequency\": 1,\n            },\n        }\n    return {\"optimizer\": optimizer}\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.forward","title":"forward(x)","text":"

    Perform a forward pass through the model.

    Parameters:

    Name Type Description Default x TensorType[batch, 1, 28, 28]

    A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.

    required

    Returns:

    Type Description TensorType[batch, 10]

    A tensor of shape (batch_size, 10) representing the logits for each class.

    Source code in src/models/mnist_module.py
    @typechecked\ndef forward(self, x: TensorType[\"batch\", 1, 28, 28]) -> TensorType[\"batch\", 10]:  # noqa\n    \"\"\"Perform a forward pass through the model.\n\n    Args:\n        x: A tensor of shape (batch_size, 1, 28, 28) representing the MNIST images.\n\n    Returns:\n        A tensor of shape (batch_size, 10) representing the logits for each class.\n    \"\"\"\n    return self.net(x)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.model_step","title":"model_step(x, y)","text":"

    Perform a single model step.

    Parameters:

    Name Type Description Default x TensorType[batch, 1, 28, 28]

    Tensor of shape [batch, 1, 28, 28] representing the images.

    required y TensorType[batch]

    Tensor of shape [batch] representing the classes.

    required

    Returns:

    Type Description

    A tuple containing: - loss: A tensor of shape (batch_size,) - preds: A tensor of predicted class indices (batch_size,) - targets: A tensor of true class labels (batch_size,)

    Source code in src/models/mnist_module.py
    @typechecked\ndef model_step(self, x: TensorType[\"batch\", 1, 28, 28], y: TensorType[\"batch\"]):  # noqa\n    \"\"\"Perform a single model step.\n\n    Args:\n        x: Tensor of shape [batch, 1, 28, 28] representing the images.\n        y: Tensor of shape [batch] representing the classes.\n\n    Returns:\n        A tuple containing:\n            - loss: A tensor of shape (batch_size,)\n            - preds: A tensor of predicted class indices (batch_size,)\n            - targets: A tensor of true class labels (batch_size,)\n    \"\"\"\n    logits = self.forward(x)\n    loss = self.criterion(logits, y)\n    preds = torch.argmax(logits, dim=1)\n    return loss, preds, y\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_test_epoch_end","title":"on_test_epoch_end()","text":"

    Lightning hook that is called when a test epoch ends.

    Source code in src/models/mnist_module.py
    def on_test_epoch_end(self) -> None:\n    \"\"\"Lightning hook that is called when a test epoch ends.\"\"\"\n    pass\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_train_epoch_end","title":"on_train_epoch_end()","text":"

    Lightning hook that is called when a training epoch ends.

    Source code in src/models/mnist_module.py
    def on_train_epoch_end(self) -> None:\n    \"\"\"Lightning hook that is called when a training epoch ends.\"\"\"\n    pass\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_train_start","title":"on_train_start()","text":"

    Lightning hook that is called when training begins.

    Source code in src/models/mnist_module.py
    def on_train_start(self) -> None:\n    \"\"\"Lightning hook that is called when training begins.\"\"\"\n    # by default lightning executes validation step sanity checks before training starts,\n    # so it's worth to make sure validation metrics don't store results from these checks\n    self.val_loss.reset()\n    self.val_acc.reset()\n    self.val_acc_best.reset()\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.on_validation_epoch_end","title":"on_validation_epoch_end()","text":"

    Lightning hook that is called when a validation epoch ends.

    Source code in src/models/mnist_module.py
    def on_validation_epoch_end(self) -> None:\n    \"\"\"Lightning hook that is called when a validation epoch ends.\"\"\"\n    acc = self.val_acc.compute()  # get current val acc\n    self.val_acc_best(acc)  # update best so far val acc\n    # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object\n    # otherwise metric would be reset by lightning after each epoch\n    self.log(\"val/acc_best\", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.setup","title":"setup(stage)","text":"

    Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

    This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

    Parameters:

    Name Type Description Default stage str

    Either \"fit\", \"validate\", \"test\", or \"predict\".

    required Source code in src/models/mnist_module.py
    def setup(self, stage: str) -> None:\n    \"\"\"Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.\n\n    This is a good hook when you need to build models dynamically or adjust something about\n    them. This hook is called on every process when using DDP.\n\n    Args:\n        stage: Either `\"fit\"`, `\"validate\"`, `\"test\"`, or `\"predict\"`.\n    \"\"\"\n    if self.hparams.compile_model and stage == \"fit\":\n        self.net = torch.compile(self.net)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.test_step","title":"test_step(batch, batch_idx)","text":"

    Perform a single test step on a batch of data from the test set.

    Parameters:

    Name Type Description Default batch tuple[Tensor, Tensor]

    A batch of data (a tuple) containing the input tensor of images and target labels.

    required batch_idx int

    The index of the current batch.

    required Source code in src/models/mnist_module.py
    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n    \"\"\"Perform a single test step on a batch of data from the test set.\n\n    Args:\n        batch: A batch of data (a tuple) containing the input tensor of images and target\n            labels.\n        batch_idx: The index of the current batch.\n    \"\"\"\n    x, y = batch\n    loss, preds, targets = self.model_step(x, y)\n\n    # update and log metrics\n    self.test_loss(loss)\n    self.test_acc(preds, targets)\n    self.log(\"test/loss\", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)\n    self.log(\"test/acc\", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.training_step","title":"training_step(batch)","text":"

    Perform a single training step.

    Parameters:

    Name Type Description Default batch Any

    A tuple containing input images and target labels.

    required batch_idx

    The index of the current batch.

    required

    Returns:

    Type Description TensorType[]

    A scalar loss tensor.

    Source code in src/models/mnist_module.py
    @typechecked\ndef training_step(self, batch: Any) -> TensorType[()]:\n    \"\"\"Perform a single training step.\n\n    Args:\n        batch: A tuple containing input images and target labels.\n        batch_idx: The index of the current batch.\n\n    Returns:\n        A scalar loss tensor.\n    \"\"\"\n    x, y = batch\n    loss, preds, targets = self.model_step(x, y)\n    self.train_loss(loss)\n    self.train_acc(preds, targets)\n    self.log(\"train/loss\", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)\n    self.log(\"train/acc\", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)\n    return loss\n
    "},{"location":"api/models/mnist_module/#src.models.mnist_module.MNISTLitModule.validation_step","title":"validation_step(batch, batch_idx)","text":"

    Perform a single validation step on a batch of data from the validation set.

    Parameters:

    Name Type Description Default batch tuple[Tensor, Tensor]

    A batch of data (a tuple) containing the input tensor of images and target labels.

    required batch_idx int

    The index of the current batch.

    required Source code in src/models/mnist_module.py
    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n    \"\"\"Perform a single validation step on a batch of data from the validation set.\n\n    Args:\n        batch: A batch of data (a tuple) containing the input tensor of images and target\n            labels.\n        batch_idx: The index of the current batch.\n    \"\"\"\n    x, y = batch\n    loss, preds, targets = self.model_step(x, y)\n\n    # update and log metrics\n    self.val_loss(loss)\n    self.val_acc(preds, targets)\n    self.log(\"val/loss\", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)\n    self.log(\"val/acc\", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)\n
    "},{"location":"api/models/components/simple_dense_net/","title":"Simple dense net","text":"

    Simple dense neural network.

    "},{"location":"api/models/components/simple_dense_net/#src.models.components.simple_dense_net.SimpleDenseNet","title":"SimpleDenseNet","text":"

    Bases: Module

    A simple fully-connected neural net for computing predictions.

    Source code in src/models/components/simple_dense_net.py
    class SimpleDenseNet(nn.Module):\n    \"\"\"A simple fully-connected neural net for computing predictions.\"\"\"\n\n    def __init__(\n        self,\n        input_size: int = 784,\n        lin1_size: int = 256,\n        lin2_size: int = 256,\n        lin3_size: int = 256,\n        output_size: int = 10,\n    ) -> None:\n        \"\"\"Initialize a `SimpleDenseNet` module.\n\n        Args:\n            input_size: The number of input features.\n            lin1_size: The number of output features of the first linear layer.\n            lin2_size: The number of output features of the second linear layer.\n            lin3_size: The number of output features of the third linear layer.\n            output_size: The number of output features of the final linear layer.\n        \"\"\"\n        super().__init__()\n\n        self.model = nn.Sequential(\n            nn.Linear(input_size, lin1_size),\n            nn.BatchNorm1d(lin1_size),\n            nn.ReLU(),\n            nn.Linear(lin1_size, lin2_size),\n            nn.BatchNorm1d(lin2_size),\n            nn.ReLU(),\n            nn.Linear(lin2_size, lin3_size),\n            nn.BatchNorm1d(lin3_size),\n            nn.ReLU(),\n            nn.Linear(lin3_size, output_size),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Perform a single forward pass through the network.\n\n        Args:\n            x: The input tensor.\n\n        Returns:\n            A tensor of predictions.\n        \"\"\"\n        batch_size, channels, width, height = x.size()\n\n        # (batch, 1, width, height) -> (batch, 1*width*height)\n        x = x.view(batch_size, -1)\n\n        return self.model(x)\n
    "},{"location":"api/models/components/simple_dense_net/#src.models.components.simple_dense_net.SimpleDenseNet.__init__","title":"__init__(input_size=784, lin1_size=256, lin2_size=256, lin3_size=256, output_size=10)","text":"

    Initialize a SimpleDenseNet module.

    Parameters:

    Name Type Description Default input_size int

    The number of input features.

    784 lin1_size int

    The number of output features of the first linear layer.

    256 lin2_size int

    The number of output features of the second linear layer.

    256 lin3_size int

    The number of output features of the third linear layer.

    256 output_size int

    The number of output features of the final linear layer.

    10 Source code in src/models/components/simple_dense_net.py
    def __init__(\n    self,\n    input_size: int = 784,\n    lin1_size: int = 256,\n    lin2_size: int = 256,\n    lin3_size: int = 256,\n    output_size: int = 10,\n) -> None:\n    \"\"\"Initialize a `SimpleDenseNet` module.\n\n    Args:\n        input_size: The number of input features.\n        lin1_size: The number of output features of the first linear layer.\n        lin2_size: The number of output features of the second linear layer.\n        lin3_size: The number of output features of the third linear layer.\n        output_size: The number of output features of the final linear layer.\n    \"\"\"\n    super().__init__()\n\n    self.model = nn.Sequential(\n        nn.Linear(input_size, lin1_size),\n        nn.BatchNorm1d(lin1_size),\n        nn.ReLU(),\n        nn.Linear(lin1_size, lin2_size),\n        nn.BatchNorm1d(lin2_size),\n        nn.ReLU(),\n        nn.Linear(lin2_size, lin3_size),\n        nn.BatchNorm1d(lin3_size),\n        nn.ReLU(),\n        nn.Linear(lin3_size, output_size),\n    )\n
    "},{"location":"api/models/components/simple_dense_net/#src.models.components.simple_dense_net.SimpleDenseNet.forward","title":"forward(x)","text":"

    Perform a single forward pass through the network.

    Parameters:

    Name Type Description Default x Tensor

    The input tensor.

    required

    Returns:

    Type Description Tensor

    A tensor of predictions.

    Source code in src/models/components/simple_dense_net.py
    def forward(self, x: torch.Tensor) -> torch.Tensor:\n    \"\"\"Perform a single forward pass through the network.\n\n    Args:\n        x: The input tensor.\n\n    Returns:\n        A tensor of predictions.\n    \"\"\"\n    batch_size, channels, width, height = x.size()\n\n    # (batch, 1, width, height) -> (batch, 1*width*height)\n    x = x.view(batch_size, -1)\n\n    return self.model(x)\n
    "},{"location":"api/serve_apis/mnist_serve/","title":"Mnist serve","text":"

    This is an example of a LitServe api for the Mnist LightningModule.

    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI","title":"MNISTServeAPI","text":"

    Bases: LitAPI

    LitServe API for serving the MNIST model.

    Source code in src/serve_apis/mnist_serve.py
    class MNISTServeAPI(ls.LitAPI):\n    \"\"\"LitServe API for serving the MNIST model.\"\"\"\n\n    def __init__(self, model_class: lightning.pytorch.LightningModule, checkpoint_path: str):\n        \"\"\"Initialize the MNISTServeAPI.\n\n        Args:\n            model_class: The LightningModule class to serve.\n            checkpoint_path: The path to the model checkpoint.\n        \"\"\"\n        self.checkpoint_path = checkpoint_path\n        self.model_class = model_class\n\n    def setup(self, device: str):\n        \"\"\"Setup is called once at startup.\n\n        Load the model, set the device, and prepare any other necessary components.\n        \"\"\"\n        # Load the trained MNIST model (ensure model weights are loaded properly here)\n        self.model = self.model_class.load_from_checkpoint(self.checkpoint_path)\n        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        self.model.to(device)  # Move the model to the appropriate device (CPU or GPU)\n        self.model.eval()  # Set the model to evaluation mode\n\n        # Define transforms that match the training data processing pipeline\n        self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n\n    def decode_request(self, request: dict):\n        \"\"\"Decode the incoming request and prepare the input for the model.\"\"\"\n        # Convert the request payload into a tensor for model input\n        image_data = request[\"image\"]\n        # Ensure that the image is a tensor of shape [1, 28, 28] (MNIST image dimensions)\n        image_tensor = torch.tensor(image_data).unsqueeze(0)  # Add a batch dimension\n        return self.transforms(image_tensor)  # Apply the necessary transformations\n\n    def predict(self, x: torch.Tensor):\n        \"\"\"Run inference using the MNIST model and return the prediction.\"\"\"\n        # Forward pass through the model\n        with torch.no_grad():\n            logits = self.model(x.unsqueeze(0))  # Add batch dimension for inference\n            preds = torch.argmax(logits, dim=1)  # Get the predicted class\n        return {\"prediction\": preds.item()}  # Return the prediction as a dictionary\n\n    def encode_response(self, output: dict):\n        \"\"\"Encode the model's output into a response payload.\"\"\"\n        # Simply pass the output as the response\n        return {\"predicted_class\": output[\"prediction\"]}\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.__init__","title":"__init__(model_class, checkpoint_path)","text":"

    Initialize the MNISTServeAPI.

    Parameters:

    Name Type Description Default model_class LightningModule

    The LightningModule class to serve.

    required checkpoint_path str

    The path to the model checkpoint.

    required Source code in src/serve_apis/mnist_serve.py
    def __init__(self, model_class: lightning.pytorch.LightningModule, checkpoint_path: str):\n    \"\"\"Initialize the MNISTServeAPI.\n\n    Args:\n        model_class: The LightningModule class to serve.\n        checkpoint_path: The path to the model checkpoint.\n    \"\"\"\n    self.checkpoint_path = checkpoint_path\n    self.model_class = model_class\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.decode_request","title":"decode_request(request)","text":"

    Decode the incoming request and prepare the input for the model.

    Source code in src/serve_apis/mnist_serve.py
    def decode_request(self, request: dict):\n    \"\"\"Decode the incoming request and prepare the input for the model.\"\"\"\n    # Convert the request payload into a tensor for model input\n    image_data = request[\"image\"]\n    # Ensure that the image is a tensor of shape [1, 28, 28] (MNIST image dimensions)\n    image_tensor = torch.tensor(image_data).unsqueeze(0)  # Add a batch dimension\n    return self.transforms(image_tensor)  # Apply the necessary transformations\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.encode_response","title":"encode_response(output)","text":"

    Encode the model's output into a response payload.

    Source code in src/serve_apis/mnist_serve.py
    def encode_response(self, output: dict):\n    \"\"\"Encode the model's output into a response payload.\"\"\"\n    # Simply pass the output as the response\n    return {\"predicted_class\": output[\"prediction\"]}\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.predict","title":"predict(x)","text":"

    Run inference using the MNIST model and return the prediction.

    Source code in src/serve_apis/mnist_serve.py
    def predict(self, x: torch.Tensor):\n    \"\"\"Run inference using the MNIST model and return the prediction.\"\"\"\n    # Forward pass through the model\n    with torch.no_grad():\n        logits = self.model(x.unsqueeze(0))  # Add batch dimension for inference\n        preds = torch.argmax(logits, dim=1)  # Get the predicted class\n    return {\"prediction\": preds.item()}  # Return the prediction as a dictionary\n
    "},{"location":"api/serve_apis/mnist_serve/#src.serve_apis.mnist_serve.MNISTServeAPI.setup","title":"setup(device)","text":"

    Setup is called once at startup.

    Load the model, set the device, and prepare any other necessary components.

    Source code in src/serve_apis/mnist_serve.py
    def setup(self, device: str):\n    \"\"\"Setup is called once at startup.\n\n    Load the model, set the device, and prepare any other necessary components.\n    \"\"\"\n    # Load the trained MNIST model (ensure model weights are loaded properly here)\n    self.model = self.model_class.load_from_checkpoint(self.checkpoint_path)\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    self.model.to(device)  # Move the model to the appropriate device (CPU or GPU)\n    self.model.eval()  # Set the model to evaluation mode\n\n    # Define transforms that match the training data processing pipeline\n    self.transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n
    "},{"location":"api/utils/download_utils/","title":"Download utils","text":"

    Utility functions aimed at downloading any data from external sources.

    "},{"location":"api/utils/download_utils/#src.utils.download_utils.download_cloud_directory","title":"download_cloud_directory(cloud_directory, output_folder, cloud='gs')","text":"

    Download a given cloud directory.

    Parameters:

    Name Type Description Default cloud_directory str

    for example gs://bucket-name/path/to/directory

    required output_folder str

    where the data downloaded will be stored (ideally data/ folder)

    required cloud str

    the cloud provider, currently only \"gs\" is supported

    'gs' Source code in src/utils/download_utils.py
    def download_cloud_directory(cloud_directory: str, output_folder: str, cloud: str = \"gs\") -> None:\n    \"\"\"Download a given cloud directory.\n\n    Args:\n        cloud_directory: for example gs://bucket-name/path/to/directory\n        output_folder: where the data downloaded will be stored (ideally data/ folder)\n        cloud: the cloud provider, currently only \"gs\" is supported\n    \"\"\"\n    cloudpathlib.Path(cloud_directory).download_to(output_folder)\n
    "},{"location":"api/utils/download_utils/#src.utils.download_utils.download_kaggle_dataset","title":"download_kaggle_dataset(dataset_name, output_folder)","text":"

    Download a given Kaggle dataset.

    Parameters:

    Name Type Description Default dataset_name str

    for example googleai/pfam-seed-random-split

    required output_folder str

    where the data downloaded will be stored (ideally data/ folder)

    required Source code in src/utils/download_utils.py
    def download_kaggle_dataset(dataset_name: str, output_folder: str) -> None:\n    \"\"\"Download a given Kaggle dataset.\n\n    Args:\n        dataset_name: for example googleai/pfam-seed-random-split\n        output_folder: where the data downloaded will be stored (ideally data/ folder)\n    \"\"\"\n    from kaggle.api.kaggle_api_extended import KaggleApi\n\n    api = KaggleApi()\n    log.info(\"Authenticating to Kaggle API\")\n    api.authenticate()\n    log.info(\"Downloading dataset\")\n    api.dataset_download_files(dataset_name, path=output_folder, unzip=True, quiet=False)\n    log.info(\"Download successful\")\n
    "},{"location":"api/utils/instantiators/","title":"Instantiators","text":"

    Module to instantiate different objects types.

    "},{"location":"api/utils/instantiators/#src.utils.instantiators.instantiate_callbacks","title":"instantiate_callbacks(callbacks_cfg)","text":"

    Instantiates callbacks from config.

    :param callbacks_cfg: A DictConfig object containing callback configurations. :return: A list of instantiated callbacks.

    Source code in src/utils/instantiators.py
    def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]:\n    \"\"\"Instantiates callbacks from config.\n\n    :param callbacks_cfg: A DictConfig object containing callback configurations.\n    :return: A list of instantiated callbacks.\n    \"\"\"\n    callbacks: list[Callback] = []\n\n    if not callbacks_cfg:\n        log.warning(\"No callback configs found! Skipping..\")\n        return callbacks\n\n    if not isinstance(callbacks_cfg, DictConfig):\n        raise TypeError(\"Callbacks config must be a DictConfig!\")  # noqa: TRY003\n\n    for _, cb_conf in callbacks_cfg.items():\n        if isinstance(cb_conf, DictConfig) and \"_target_\" in cb_conf:\n            log.info(f\"Instantiating callback <{cb_conf._target_}>\")\n            callbacks.append(hydra.utils.instantiate(cb_conf))\n\n    return callbacks\n
    "},{"location":"api/utils/instantiators/#src.utils.instantiators.instantiate_loggers","title":"instantiate_loggers(logger_cfg)","text":"

    Instantiates loggers from config.

    :param logger_cfg: A DictConfig object containing logger configurations. :return: A list of instantiated loggers.

    Source code in src/utils/instantiators.py
    def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]:\n    \"\"\"Instantiates loggers from config.\n\n    :param logger_cfg: A DictConfig object containing logger configurations.\n    :return: A list of instantiated loggers.\n    \"\"\"\n    logger: list[Logger] = []\n\n    if not logger_cfg:\n        log.warning(\"No logger configs found! Skipping...\")\n        return logger\n\n    if not isinstance(logger_cfg, DictConfig):\n        raise TypeError(\"Logger config must be a DictConfig!\")  # noqa: TRY003\n\n    for _, lg_conf in logger_cfg.items():\n        if isinstance(lg_conf, DictConfig) and \"_target_\" in lg_conf:\n            log.info(f\"Instantiating logger <{lg_conf._target_}>\")\n            logger.append(hydra.utils.instantiate(lg_conf))\n\n    return logger\n
    "},{"location":"api/utils/logging_utils/","title":"Logging utils","text":"

    Logging utility instantiator.

    "},{"location":"api/utils/logging_utils/#src.utils.logging_utils.log_hyperparameters","title":"log_hyperparameters(object_dict)","text":"

    Controls which config parts are saved by Lightning loggers.

    Additionally saves number of model parameters

    Parameters:

    Name Type Description Default object_dict dict[str, Any]

    A dictionary containing the following objects: cfg, model, trainer.

    required Source code in src/utils/logging_utils.py
    @rank_zero_only\ndef log_hyperparameters(object_dict: dict[str, Any]) -> None:\n    \"\"\"Controls which config parts are saved by Lightning loggers.\n\n    Additionally saves number of model parameters\n\n    Args:\n        object_dict: A dictionary containing the following objects: cfg, model, trainer.\n    \"\"\"\n    hparams = {}\n\n    cfg = OmegaConf.to_container(object_dict[\"cfg\"])\n    model = object_dict[\"model\"]\n    trainer = object_dict[\"trainer\"]\n\n    if not trainer.logger:\n        log.warning(\"Logger not found! Skipping hyperparameter logging...\")\n        return\n\n    hparams[\"model\"] = cfg[\"model\"]\n\n    # save number of model parameters\n    hparams[\"model/params/total\"] = sum(p.numel() for p in model.parameters())\n    hparams[\"model/params/trainable\"] = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    hparams[\"model/params/non_trainable\"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)\n\n    hparams[\"data\"] = cfg[\"data\"]\n    hparams[\"trainer\"] = cfg[\"trainer\"]\n\n    hparams[\"callbacks\"] = cfg.get(\"callbacks\")\n    hparams[\"extras\"] = cfg.get(\"extras\")\n\n    hparams[\"task_name\"] = cfg.get(\"task_name\")\n    hparams[\"tags\"] = cfg.get(\"tags\")\n    hparams[\"ckpt_path\"] = cfg.get(\"ckpt_path\")\n    hparams[\"seed\"] = cfg.get(\"seed\")\n    hparams[\"execution_command\"] = f\"python {' '.join(sys.argv)}\"\n\n    # send hparams to all loggers\n    for logger in trainer.loggers:\n        logger.log_hyperparams(hparams)\n
    "},{"location":"api/utils/pylogger/","title":"Pylogger","text":"

    Code for logging on multi-GPU-friendly.

    "},{"location":"api/utils/pylogger/#src.utils.pylogger.RankedLogger","title":"RankedLogger","text":"

    Bases: LoggerAdapter

    A multi-GPU-friendly python command line logger.

    Source code in src/utils/pylogger.py
    class RankedLogger(logging.LoggerAdapter):\n    \"\"\"A multi-GPU-friendly python command line logger.\"\"\"\n\n    def __init__(\n        self,\n        name: str = __name__,\n        rank_zero_only: bool = False,\n        extra: Mapping[str, object] | None = None,\n    ) -> None:\n        \"\"\"Initializes a multi-GPU-friendly python command line logger that logs.\n\n        On all processes with their rank prefixed in the log message.\n\n        Args:\n            name: The name of the logger. Default is ``__name__``.\n            rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.\n            extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.\n        \"\"\"\n        logger = logging.getLogger(name)\n        super().__init__(logger=logger, extra=extra)\n        self.rank_zero_only = rank_zero_only\n\n    def log(self, level: int, msg: str, rank: int | None = None, *args, **kwargs) -> None:  # type: ignore\n        \"\"\"Delegate a log call to the underlying logger.\n\n        After prefixing its message with the rank\n        of the process it's being logged from. If `'rank'` is provided, then the log will only\n        occur on that rank/process.\n\n        Args:\n            level: The level to log at. Look at `logging.__init__.py` for more information.\n            msg: The message to log.\n            rank: The rank to log at.\n            args: Additional args to pass to the underlying logging function.\n            kwargs: Any additional keyword args to pass to the underlying logging function.\n        \"\"\"\n        if self.isEnabledFor(level):\n            msg, kwargs = self.process(msg, kwargs)  # type: ignore\n            current_rank = getattr(rank_zero_only, \"rank\", None)\n            if current_rank is None:\n                raise RuntimeError(\"The `rank_zero_only.rank` needs to be set before use\")  # noqa\n            msg = rank_prefixed_message(msg, current_rank)\n            if self.rank_zero_only:\n                if current_rank == 0:\n                    self.logger.log(level, msg, *args, **kwargs)\n            else:\n                if rank is None or current_rank == rank:\n                    self.logger.log(level, msg, *args, **kwargs)\n
    "},{"location":"api/utils/pylogger/#src.utils.pylogger.RankedLogger.__init__","title":"__init__(name=__name__, rank_zero_only=False, extra=None)","text":"

    Initializes a multi-GPU-friendly python command line logger that logs.

    On all processes with their rank prefixed in the log message.

    Parameters:

    Name Type Description Default name str

    The name of the logger. Default is __name__.

    __name__ rank_zero_only bool

    Whether to force all logs to only occur on the rank zero process. Default is False.

    False extra Mapping[str, object] | None

    (Optional) A dict-like object which provides contextual information. See logging.LoggerAdapter.

    None Source code in src/utils/pylogger.py
    def __init__(\n    self,\n    name: str = __name__,\n    rank_zero_only: bool = False,\n    extra: Mapping[str, object] | None = None,\n) -> None:\n    \"\"\"Initializes a multi-GPU-friendly python command line logger that logs.\n\n    On all processes with their rank prefixed in the log message.\n\n    Args:\n        name: The name of the logger. Default is ``__name__``.\n        rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.\n        extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.\n    \"\"\"\n    logger = logging.getLogger(name)\n    super().__init__(logger=logger, extra=extra)\n    self.rank_zero_only = rank_zero_only\n
    "},{"location":"api/utils/pylogger/#src.utils.pylogger.RankedLogger.log","title":"log(level, msg, rank=None, *args, **kwargs)","text":"

    Delegate a log call to the underlying logger.

    After prefixing its message with the rank of the process it's being logged from. If 'rank' is provided, then the log will only occur on that rank/process.

    Parameters:

    Name Type Description Default level int

    The level to log at. Look at logging.__init__.py for more information.

    required msg str

    The message to log.

    required rank int | None

    The rank to log at.

    None args

    Additional args to pass to the underlying logging function.

    () kwargs

    Any additional keyword args to pass to the underlying logging function.

    {} Source code in src/utils/pylogger.py
    def log(self, level: int, msg: str, rank: int | None = None, *args, **kwargs) -> None:  # type: ignore\n    \"\"\"Delegate a log call to the underlying logger.\n\n    After prefixing its message with the rank\n    of the process it's being logged from. If `'rank'` is provided, then the log will only\n    occur on that rank/process.\n\n    Args:\n        level: The level to log at. Look at `logging.__init__.py` for more information.\n        msg: The message to log.\n        rank: The rank to log at.\n        args: Additional args to pass to the underlying logging function.\n        kwargs: Any additional keyword args to pass to the underlying logging function.\n    \"\"\"\n    if self.isEnabledFor(level):\n        msg, kwargs = self.process(msg, kwargs)  # type: ignore\n        current_rank = getattr(rank_zero_only, \"rank\", None)\n        if current_rank is None:\n            raise RuntimeError(\"The `rank_zero_only.rank` needs to be set before use\")  # noqa\n        msg = rank_prefixed_message(msg, current_rank)\n        if self.rank_zero_only:\n            if current_rank == 0:\n                self.logger.log(level, msg, *args, **kwargs)\n        else:\n            if rank is None or current_rank == rank:\n                self.logger.log(level, msg, *args, **kwargs)\n
    "},{"location":"api/utils/rich_utils/","title":"Rich utils","text":"

    Rich utils to print config tree.

    "},{"location":"api/utils/rich_utils/#src.utils.rich_utils.enforce_tags","title":"enforce_tags(cfg, save_to_file=False)","text":"

    Prompts user to input tags from command line if no tags are provided in config.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig composed by Hydra.

    required save_to_file bool

    Whether to export tags to the hydra output folder. Default is False.

    False Source code in src/utils/rich_utils.py
    @rank_zero_only\ndef enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:\n    \"\"\"Prompts user to input tags from command line if no tags are provided in config.\n\n    Args:\n        cfg: A DictConfig composed by Hydra.\n        save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.\n    \"\"\"\n    if not cfg.get(\"tags\"):\n        if \"id\" in HydraConfig().cfg.hydra.job:\n            raise ValueError(\"Specify tags before launching a multirun!\")  # noqa\n\n        log.warning(\"No tags provided in config. Prompting user to input tags...\")\n        tags = Prompt.ask(\"Enter a list of comma separated tags\", default=\"dev\")\n        tags = [t.strip() for t in tags.split(\",\") if t != \"\"]\n\n        with open_dict(cfg):\n            cfg.tags = tags\n\n        log.info(f\"Tags: {cfg.tags}\")\n\n    if save_to_file:\n        with open(Path(cfg.paths.output_dir, \"tags.log\"), \"w\") as file:\n            rich.print(cfg.tags, file=file)\n
    "},{"location":"api/utils/rich_utils/#src.utils.rich_utils.print_config_tree","title":"print_config_tree(cfg, print_order=('data', 'model', 'callbacks', 'logger', 'trainer', 'paths', 'extras'), resolve=False, save_to_file=False)","text":"

    Prints the contents of a DictConfig as a tree structure using the Rich library.

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig composed by Hydra.

    required print_order Sequence[str]

    Determines in what order config components are printed. Default is ``(\"data\", \"model\",

    ('data', 'model', 'callbacks', 'logger', 'trainer', 'paths', 'extras') resolve bool

    Whether to resolve reference fields of DictConfig. Default is False.

    False save_to_file bool

    Whether to export config to the hydra output folder. Default is False.

    False Source code in src/utils/rich_utils.py
    @rank_zero_only\ndef print_config_tree(\n    cfg: DictConfig,\n    print_order: Sequence[str] = (\n        \"data\",\n        \"model\",\n        \"callbacks\",\n        \"logger\",\n        \"trainer\",\n        \"paths\",\n        \"extras\",\n    ),\n    resolve: bool = False,\n    save_to_file: bool = False,\n) -> None:\n    \"\"\"Prints the contents of a DictConfig as a tree structure using the Rich library.\n\n    Args:\n        cfg: A DictConfig composed by Hydra.\n        print_order: Determines in what order config components are printed. Default is ``(\"data\", \"model\",\n        \"callbacks\", \"logger\", \"trainer\", \"paths\", \"extras\")``.\n        resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.\n        save_to_file: Whether to export config to the hydra output folder. Default is ``False``.\n    \"\"\"\n    style = \"dim\"\n    tree = rich.tree.Tree(\"CONFIG\", style=style, guide_style=style)\n\n    queue = []\n\n    # add fields from `print_order` to queue\n    for field in print_order:\n        queue.append(field) if field in cfg else log.warning(\n            f\"Field '{field}' not found in config. Skipping '{field}' config printing...\"\n        )\n\n    # add all the other fields to queue (not specified in `print_order`)\n    for field in cfg:\n        if field not in queue:\n            queue.append(field)\n\n    # generate config tree from queue\n    for field in queue:\n        branch = tree.add(field, style=style, guide_style=style)\n\n        config_group = cfg[field]\n        if isinstance(config_group, DictConfig):\n            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)\n        else:\n            branch_content = str(config_group)\n\n        branch.add(rich.syntax.Syntax(branch_content, \"yaml\"))\n\n    # print config tree\n    rich.print(tree)\n\n    # save config tree to file\n    if save_to_file:\n        with open(Path(cfg.paths.output_dir, \"config_tree.log\"), \"w\") as file:\n            rich.print(tree, file=file)\n
    "},{"location":"api/utils/utils/","title":"Utils","text":"

    Utility functions for various tasks.

    "},{"location":"api/utils/utils/#src.utils.utils.extras","title":"extras(cfg)","text":"

    Applies optional utilities before the task is started.

    Utilities

    Parameters:

    Name Type Description Default cfg DictConfig

    A DictConfig object containing the config tree.

    required Source code in src/utils/utils.py
    def extras(cfg: DictConfig) -> None:\n    \"\"\"Applies optional utilities before the task is started.\n\n    Utilities:\n        - Ignoring python warnings\n        - Setting tags from command line\n        - Rich config printing\n\n    Args:\n        cfg: A DictConfig object containing the config tree.\n    \"\"\"\n    # return if no `extras` config\n    if not cfg.get(\"extras\"):\n        log.warning(\"Extras config not found! <cfg.extras=null>\")\n        return\n\n    # disable python warnings\n    if cfg.extras.get(\"ignore_warnings\"):\n        log.info(\"Disabling python warnings! <cfg.extras.ignore_warnings=True>\")\n        warnings.filterwarnings(\"ignore\")\n\n    # prompt user to input tags from command line if none are provided in the config\n    if cfg.extras.get(\"enforce_tags\"):\n        log.info(\"Enforcing tags! <cfg.extras.enforce_tags=True>\")\n        rich_utils.enforce_tags(cfg, save_to_file=True)\n\n    # pretty print config tree using Rich library\n    if cfg.extras.get(\"print_config\"):\n        log.info(\"Printing config tree with Rich! <cfg.extras.print_config=True>\")\n        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)\n
    "},{"location":"api/utils/utils/#src.utils.utils.fetch_data","title":"fetch_data(url)","text":"

    Fetches data from a URL.

    Source code in src/utils/utils.py
    def fetch_data(url):\n    \"\"\"Fetches data from a URL.\"\"\"\n    response = requests.get(url)\n    if response.status_code == 200:\n        return response.json()\n    return None\n
    "},{"location":"api/utils/utils/#src.utils.utils.file_lock","title":"file_lock(filename, mode='r')","text":"

    This context manager is used to acquire a file lock on a file.

    particularly useful for shared resources in multi-process environments (multi GPU/TPU training).

    Parameters:

    Name Type Description Default filename Path

    Path to the file to lock

    required mode str

    The mode to open the file with, either \"r\" or \"w\"

    'r'

    Raises:

    Type Description ValueError

    If the mode is invalid (neither \"r\" nor \"w\")

    Source code in src/utils/utils.py
    @contextlib.contextmanager\ndef file_lock(filename: Path, mode: str = \"r\") -> Any:\n    \"\"\"This context manager is used to acquire a file lock on a file.\n\n    particularly useful for shared resources in multi-process environments (multi GPU/TPU training).\n\n    Args:\n        filename: Path to the file to lock\n        mode: The mode to open the file with, either \"r\" or \"w\"\n\n    Raises:\n        ValueError: If the mode is invalid (neither \"r\" nor \"w\")\n    \"\"\"\n    with open(filename, mode) as f:\n        try:\n            match mode:\n                case \"r\":\n                    fcntl.flock(f.fileno(), fcntl.LOCK_SH)\n                case \"w\":\n                    fcntl.flock(f.fileno(), fcntl.LOCK_EX)\n                case _:\n                    raise ValueError(\"Expected mode 'r' or 'w'.\")  # noqa\n            yield f\n        finally:\n            fcntl.flock(f.fileno(), fcntl.LOCK_UN)\n
    "},{"location":"api/utils/utils/#src.utils.utils.file_lock_operation","title":"file_lock_operation(file_name, operation)","text":"

    This function is used to perform an operation on a file while acquiring a lock on it.

    The lock is acquired using the file_lock context manager, and based on a file stored in a temporary folder

    Parameters:

    Name Type Description Default file_name str

    Path to the file to lock

    required operation Callable

    The operation to perform on the file

    required

    Returns:

    Type Description Any

    The result of the operation

    Source code in src/utils/utils.py
    @contextlib.contextmanager\ndef file_lock_operation(file_name: str, operation: Callable) -> Any:\n    \"\"\"This function is used to perform an operation on a file while acquiring a lock on it.\n\n    The lock is acquired using the `file_lock` context manager, and based on a file stored in a temporary folder\n\n    Args:\n        file_name: Path to the file to lock\n        operation: The operation to perform on the file\n\n    Returns:\n        The result of the operation\n    \"\"\"\n    with tempfile.TemporaryDirectory() as temp_dir:\n        file_path = Path(temp_dir) / file_name\n        with file_lock(file_path, mode=\"w\"):\n            result = operation(file_path)\n        return result\n
    "},{"location":"api/utils/utils/#src.utils.utils.get_metric_value","title":"get_metric_value(metric_dict, metric_name)","text":"

    Safely retrieves value of the metric logged in LightningModule.

    Parameters:

    Name Type Description Default metric_dict dict[str, Any]

    A dict containing metric values.

    required metric_name str | None

    If provided, the name of the metric to retrieve.

    required

    Returns:

    Type Description None | float

    If a metric name was provided, the value of the metric.

    Source code in src/utils/utils.py
    def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> None | float:\n    \"\"\"Safely retrieves value of the metric logged in LightningModule.\n\n    Args:\n        metric_dict: A dict containing metric values.\n        metric_name: If provided, the name of the metric to retrieve.\n\n    Returns:\n        If a metric name was provided, the value of the metric.\n    \"\"\"\n    if not metric_name:\n        log.info(\"Metric name is None! Skipping metric value retrieval...\")\n        return None\n\n    if metric_name not in metric_dict:\n        raise ValueError(f\"Metric value not found! <metric_name={metric_name}>\\n\")  # noqa: TRY003\n\n    metric_value = metric_dict[metric_name].item()\n    log.info(f\"Retrieved metric value! <{metric_name}={metric_value}>\")\n\n    return metric_value\n
    "},{"location":"api/utils/utils/#src.utils.utils.process_data","title":"process_data(url)","text":"

    Fetches data from a URL and processes it.

    Source code in src/utils/utils.py
    def process_data(url):\n    \"\"\"Fetches data from a URL and processes it.\"\"\"\n    data = fetch_data(url)\n    if data:\n        return len(data)  # Just an example of processing, counting data length\n    return 0\n
    "},{"location":"api/utils/utils/#src.utils.utils.task_wrapper","title":"task_wrapper(task_func)","text":"

    Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to

    Example:

    @utils.task_wrapper\ndef train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n    ...\n    return metric_dict, object_dict\n

    Parameters:

    Name Type Description Default task_func Callable

    The task function to be wrapped.

    required

    Returns:

    Type Description Callable

    The wrapped task function.

    Source code in src/utils/utils.py
    def task_wrapper(task_func: Callable) -> Callable:\n    \"\"\"Optional decorator that controls the failure behavior when executing the task function.\n\n    This wrapper can be used to:\n        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)\n        - save the exception to a `.log` file\n        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)\n        - etc. (adjust depending on your needs)\n\n    Example:\n    ```\n    @utils.task_wrapper\n    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n        ...\n        return metric_dict, object_dict\n    ```\n\n    Args:\n        task_func: The task function to be wrapped.\n\n    Returns:\n        The wrapped task function.\n    \"\"\"\n\n    def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:\n        # execute the task\n        try:\n            metric_dict, object_dict = task_func(cfg=cfg)\n\n        # things to do if exception occurs\n        except Exception as e:\n            # save exception to `.log` file\n            log.exception(\"\")\n\n            # some hyperparameter combinations might be invalid or cause out-of-memory errors\n            # so when using hparam search plugins like Optuna, you might want to disable\n            # raising the below exception to avoid multirun failure\n            raise e  # noqa: TRY201\n\n        # things to always do after either success or exception\n        finally:\n            # display output dir path in terminal\n            log.info(f\"Output dir: {cfg.paths.output_dir}\")\n\n            # always close wandb run (even if exception occurs so multirun won't fail)\n            if find_spec(\"wandb\"):  # check if wandb is installed\n                import wandb\n\n                if wandb.run:\n                    log.info(\"Closing wandb!\")\n                    wandb.finish()\n\n        return metric_dict, object_dict\n\n    return wrap\n
    "}]} \ No newline at end of file