This repository contains an implementation of AlexNet using PyTorch for training, testing, and inference on the CIFAR-10 dataset.
ALEXNET_PYTORCH/
│── test_images/ # Directory for test images
│── src/ # Source code directory
│ │── __init__.py # Init file
│ │── dataset.py # Data loading utilities
│ │── inference.py # Model inference script
│ │── model.py # AlexNet model implementation
│ │── test.py # Script for testing the trained model
│ │── train.py # Training script
│── LICENSE # License description
│── README.md # Readme file
│── requirements.txt # Required dependencies
Install the required dependencies before running the scripts:
pip install -r requirements.txt
To train AlexNet on CIFAR-10, run the following command:
python src/train.py
The training script uses the following default parameters:
- Number of Classes: 10 (CIFAR-10 dataset)
- Epochs: 20
- Batch Size: 64
- Learning Rate: 0.001
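As an illustration, these defaults could be kept in one place and overridden from the command line. The sketch below is an assumption about how that might look, not a description of train.py: the names `TrainConfig` and `parse_args` and the flags `--epochs`, `--batch_size`, and `--lr` are hypothetical.

```python
import argparse
from dataclasses import dataclass


@dataclass
class TrainConfig:
    """Default hyperparameters from the list above (class name is hypothetical)."""
    num_classes: int = 10          # CIFAR-10 has 10 classes
    epochs: int = 20
    batch_size: int = 64
    learning_rate: float = 0.001


def parse_args(argv=None) -> TrainConfig:
    """Build a TrainConfig; the CLI flags here are assumed, not taken from train.py."""
    defaults = TrainConfig()
    parser = argparse.ArgumentParser(description="Train AlexNet on CIFAR-10")
    parser.add_argument("--epochs", type=int, default=defaults.epochs)
    parser.add_argument("--batch_size", type=int, default=defaults.batch_size)
    parser.add_argument("--lr", type=float, default=defaults.learning_rate)
    args = parser.parse_args(argv)
    return TrainConfig(
        num_classes=defaults.num_classes,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.lr,
    )


# Passing an empty list ignores the real command line and yields the defaults.
config = parse_args([])
print(config)
```

Keeping the defaults in a dataclass means the values printed in logs and the values used by the training loop come from the same object.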
The best model based on validation accuracy will be saved as best_model.pth.
The training script performs the following steps:
- Load the CIFAR-10 dataset and split it into training and validation sets.
- Initialize the AlexNet model and configure it to run on a GPU (if available).
- Define the loss function (CrossEntropyLoss) and the optimizer (SGD).
- Train the model for a specified number of epochs, calculating loss and updating weights.
- Evaluate the model on the validation set after each epoch.
- Save the model if the validation accuracy improves.
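The steps above can be sketched roughly as follows. This is a minimal illustration, not the code in train.py: the function name `train`, the synthetic CIFAR-shaped tensors, the small linear stand-in model, and saving the checkpoint as a `state_dict` are all assumptions made so the example runs without downloading the dataset.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split


def train(model, train_loader, val_loader, epochs, lr, device, ckpt_path="best_model.pth"):
    """Train with SGD and save the weights whenever validation accuracy improves."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    best_acc = -1.0  # below any real accuracy, so the first epoch is always saved

    for epoch in range(epochs):
        # Training pass: compute the loss and update the weights.
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        # Validation pass: no gradients, count correct top-1 predictions.
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / total

        # Keep only the best checkpoint seen so far.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), ckpt_path)
        print(f"epoch {epoch + 1}/{epochs}  loss={train_loss:.4f}  val_acc={val_acc:.4f}")

    return best_acc


# Demo on synthetic data shaped like CIFAR-10 (3x32x32 images, 10 classes).
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = TensorDataset(torch.randn(200, 3, 32, 32), torch.randint(0, 10, (200,)))
train_set, val_set = random_split(data, [160, 40])
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in for AlexNet
ckpt_path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
best_acc = train(model, train_loader, val_loader, epochs=2, lr=0.001,
                 device=device, ckpt_path=ckpt_path)
```

On random labels the accuracy stays near chance; the point of the demo is the control flow, in particular that the checkpoint is overwritten only when validation accuracy rises.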
Once trained, you can evaluate the model using:
python src/test.py
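For orientation, an evaluation script of this kind typically restores the saved weights and reports top-1 accuracy on held-out data. The sketch below assumes the checkpoint is a `state_dict`; the `evaluate` helper, the stand-in linear model, and the random test tensors are hypothetical and only keep the example self-contained.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device):
    """Return top-1 accuracy of `model` over every batch in `loader`."""
    model.to(device)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a trained model whose weights were saved during training.
trained = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
ckpt_path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
torch.save(trained.state_dict(), ckpt_path)

# Rebuild the same architecture, then load the saved weights into it.
restored = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
restored.load_state_dict(torch.load(ckpt_path, map_location=device))

# Random tensors stand in for the CIFAR-10 test split.
test_set = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 10, (100,)))
test_acc = evaluate(restored, DataLoader(test_set, batch_size=64), device)
print(f"Test accuracy: {test_acc:.2%}")
```

Passing `map_location` lets a checkpoint written on a GPU machine load on a CPU-only one.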
To perform inference on new images, use the inference.py script:
python src/inference.py --image_path test_images/sample.jpg
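The prediction step behind such a command can be sketched as below. This is an assumed outline rather than the contents of inference.py: `predict` is a hypothetical helper, the linear model is a stand-in, and reading the file given by `--image_path` and applying the training-time preprocessing are left out. The class names are listed in the standard CIFAR-10 label order.

```python
import torch
from torch import nn

# Standard CIFAR-10 label order (index 0 to 9).
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def predict(model, image, class_names=CIFAR10_CLASSES):
    """Return (label, probability) for one image tensor of shape (C, H, W)."""
    model.eval()
    with torch.no_grad():
        logits = model(image.unsqueeze(0))        # add a batch dimension
        probs = torch.softmax(logits, dim=1)[0]   # probabilities over the classes
    idx = int(probs.argmax())
    return class_names[idx], float(probs[idx])


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # stand-in for AlexNet
image = torch.randn(3, 32, 32)  # stand-in for a loaded, preprocessed image
label, prob = predict(model, image)
print(f"Predicted: {label} ({prob:.2%})")
```

The image passed to the model should go through the same resizing and normalization that were used during training, otherwise the predicted probabilities are not meaningful.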
The model.py script contains the AlexNet architecture defined using PyTorch. The model is adapted for CIFAR-10 by modifying the fully connected layers to match the dataset's 10 output classes.
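A minimal sketch of such a model is shown below. It follows the common single-tower AlexNet layout (five convolutional layers followed by three fully connected layers) with the final layer sized by `num_classes`; the exact layers in model.py may differ, and the 224x224 input size used in the demo is an assumption, since CIFAR-10 images are 32x32 and would need resizing for this layout.

```python
import torch
from torch import nn


class AlexNet(nn.Module):
    """AlexNet-style network; the last linear layer has `num_classes` outputs."""

    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Fixes the feature map at 6x6 so the classifier input size is constant.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),  # 10 outputs for CIFAR-10
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        return self.classifier(x)


model = AlexNet(num_classes=10).eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))  # assumed input resolution
print(logits.shape)  # torch.Size([2, 10])
```

Raw 32x32 inputs are too small for this stack: the feature map shrinks to 1x1 before the last pooling layer, so the images are either upsampled or the early kernels and strides are reduced.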
The dataset.py script includes utilities for loading and preprocessing the CIFAR-10 dataset, including splitting it into training and validation sets.
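The splitting step might look like the sketch below. It is an illustration under stated assumptions: `make_loaders` is a hypothetical helper, the 90/10 split ratio and the seed are not taken from dataset.py, and random tensors replace the real data (which would typically come from `torchvision.datasets.CIFAR10`) so the example runs offline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(dataset, val_fraction=0.1, batch_size=64, seed=42):
    """Split `dataset` into train/validation subsets and wrap both in DataLoaders."""
    val_size = int(len(dataset) * val_fraction)
    train_size = len(dataset) - val_size
    # A seeded generator makes the split reproducible across runs.
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [train_size, val_size], generator=generator)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


# Synthetic stand-in for the CIFAR-10 training set (3x32x32 images, 10 classes).
full_set = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))
train_loader, val_loader = make_loaders(full_set)
print(len(train_loader.dataset), len(val_loader.dataset))  # 900 100
```

Only the training loader shuffles; keeping the validation order fixed makes accuracy figures comparable from epoch to epoch.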
During training, validation accuracy is monitored, and the best model is saved. The final accuracy will be printed at the end of training.