This repository provides a robust and flexible framework for training image classification models using PyTorch. It's designed to be highly customizable and easy to use, allowing you to run experiments with different models, data augmentation techniques, and training configurations.
```
├── dataset
│   └── README.md
├── training_eng
│   ├── core
│   │   ├── device.py
│   │   ├── misc.py
│   │   └── callback_arg.py
│   ├── train_utils
│   │   ├── early_stopping.py
│   │   └── model_eval.py
│   ├── data_utils
│   │   ├── data_proc.py
│   │   └── data_loader.py
│   └── trainer.py
├── train_exper.py
├── tensorboard.cmd
├── logs
├── cache
├── uv.lock
├── pyproject.toml
├── tensorboard.sh
├── run_expers.py
├── GIT_COMMIT.md
├── models
└── expers.toml
```
- Experiment Management: Easily define and run multiple experiments using a simple TOML configuration file (`expers.toml`).
- Data Loading and Processing: Efficient data loading and augmentation pipelines with support for several backends (`opencv`, `pil`, `turbojpeg`).
- Flexible Training Loop: The core training loop in `training_eng/trainer.py` supports (see the sketch after this list):
  - Mixed precision training
  - Gradient accumulation
  - Learning rate schedulers
  - Early stopping
  - TensorBoard logging
  - Model compilation with `torch.compile`
- Extensible Model Support: Easily integrate any PyTorch model. The current example uses `efficientnet-pytorch`.
- Rich Console Output: Uses the `rich` library for informative, nicely formatted console output.
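For intuition, here is a minimal sketch of a training step that combines mixed precision and gradient accumulation, two of the features listed above. This is not the repository's actual `fit` implementation; `model`, `loader`, `optimizer`, and `criterion` are placeholders, and the real loop in `training_eng/trainer.py` layers scheduling, early stopping, and logging on top (`torch.compile(model)` would typically be applied once, before training starts).

```python
# Illustrative sketch only — not the repository's actual `fit` function.
import torch

def train_one_epoch(model, loader, optimizer, criterion, device, accum_steps=4):
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    model.train()
    optimizer.zero_grad(set_to_none=True)
    for step, (images, targets) in enumerate(loader):
        images, targets = images.to(device), targets.to(device)
        # Run the forward pass in mixed precision where numerically safe.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = criterion(model(images), targets) / accum_steps
        scaler.scale(loss).backward()
        # Step the optimizer only once every `accum_steps` mini-batches.
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
```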
- Python 3.11+
- PyTorch
- Other dependencies listed in `pyproject.toml`
- Clone the repository:

  ```sh
  git clone https://github.com/AidinHamedi/Pytorch-Img-Classification-Trainer-V2.git
  cd Pytorch-Img-Classification-Trainer-V2
  ```

- Install dependencies: This project uses `uv` for package management.

  ```sh
  pip install uv
  uv sync
  ```

  If you want to use turbojpeg:

  ```sh
  uv sync --extra tjpeg
  ```
Place your training and validation datasets in the `dataset/train` and `dataset/validation` directories, respectively. The data should be organized into subdirectories, where each subdirectory represents a class:
```
dataset/
├── train/
│   ├── class_a/
│   │   ├── image1.jpg
│   │   └── image2.jpg
│   └── class_b/
│       ├── image3.jpg
│       └── image4.jpg
└── validation/
    ├── class_a/
    │   ├── image5.jpg
    │   └── image6.jpg
    └── class_b/
        ├── image7.jpg
        └── image8.jpg
```
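The repository ships its own loaders in `training_eng/data_utils`, but this folder-per-class layout is the same convention that torchvision's `ImageFolder` uses, which makes it easy to sanity-check your directory structure:

```python
# Layout check only — the project uses its own loaders in training_eng/data_utils.
from torchvision import datasets, transforms

train_ds = datasets.ImageFolder(
    "dataset/train",
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
)
print(train_ds.classes)  # labels come from folder names, e.g. ['class_a', 'class_b']
```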
- Define your experiments in `expers.toml`:

  Each section in `expers.toml` represents a separate experiment. You can specify the model name and other parameters for each experiment.

  Example `expers.toml`:

  ```toml
  ["Test"]
  model_name = "efficientnet-b0"

  ["Experiment_2"]
  model_name = "efficientnet-b1"
  ```
- Configure training parameters in `train_exper.py`:

  This file contains the main configuration for the training process (sketched after this list), including:

  - Dataset paths
  - Image resolution
  - Batch size
  - Data augmentation settings
  - Optimizer and loss function
  - Other training-related hyperparameters
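The exact names and structure in `train_exper.py` may differ; this sketch only illustrates the kinds of values involved:

```python
# Illustrative only — the actual names in train_exper.py may differ.
import torch

TRAIN_DIR = "dataset/train"      # dataset paths
VAL_DIR = "dataset/validation"
IMG_SIZE = (224, 224)            # image resolution
BATCH_SIZE = 32

criterion = torch.nn.CrossEntropyLoss()

def make_optimizer(model: torch.nn.Module) -> torch.optim.Optimizer:
    # AdamW is a common default; swap in whatever your experiment calls for.
    return torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
```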
- Run the experiments:

  Execute the `run_expers.py` script to start training all the experiments defined in `expers.toml`:

  ```sh
  python run_expers.py
  ```

  The script will iterate through each experiment, train the model, and save the results.
- TensorBoard: Monitor the training process in real time using TensorBoard.
  - On Windows, run `tensorboard.cmd`.
  - On Linux/macOS, run `tensorboard.sh`.
- Saved Models: The best and latest models for each experiment are saved in the `models` directory.
- Logs: Training logs are stored in the `logs` directory.
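You can also launch TensorBoard by hand; assuming the event files land in the `logs` directory noted above, the standard invocation is:

```sh
tensorboard --logdir logs
```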
- `run_expers.py`: The main entry point. It reads the `expers.toml` file and iterates through each experiment defined in it.
- `train_exper.py`: For each experiment, this script sets up the data loaders, model, optimizer, and loss function based on the configuration, then calls the `fit` function from `training_eng/trainer.py`. It can be modified to suit your needs.
- `training_eng/trainer.py`: Contains the core `fit` function that implements the training loop. It handles the complexities of training, including mixed precision, gradient accumulation, early stopping, and logging.
- `training_eng/data_utils`: Modules that handle the creation of data pairs, data loading, and data augmentation.
- `training_eng/train_utils`: Utilities for model evaluation and early stopping (a minimal sketch of the early-stopping idea follows this list).
- `training_eng/core`: Core functionality such as device management and callback arguments.
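For intuition only, a patience-based early-stopping helper might look like the sketch below; the repository's real implementation lives in `training_eng/train_utils/early_stopping.py` and likely differs in details such as naming and supported modes.

```python
# Intuition-only sketch; not the code in training_eng/train_utils/early_stopping.py.
class EarlyStopping:
    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience    # epochs to wait without improvement
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record a validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```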
Contributions are welcome! Please feel free to submit a pull request or open an issue.
Copyright (c) 2025 Aidin Hamedi.

This software is released under the MIT License: https://opensource.org/licenses/MIT
